nodestate_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. // Copyright 2020 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package nodestate
  17. import (
  18. "errors"
  19. "fmt"
  20. "reflect"
  21. "testing"
  22. "time"
  23. "github.com/ethereum/go-ethereum/common/mclock"
  24. "github.com/ethereum/go-ethereum/core/rawdb"
  25. "github.com/ethereum/go-ethereum/p2p/enode"
  26. "github.com/ethereum/go-ethereum/p2p/enr"
  27. "github.com/ethereum/go-ethereum/rlp"
  28. )
  29. func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) {
  30. setup := &Setup{}
  31. flags := make([]Flags, len(flagPersist))
  32. for i, persist := range flagPersist {
  33. if persist {
  34. flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i))
  35. } else {
  36. flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i))
  37. }
  38. }
  39. fields := make([]Field, len(fieldType))
  40. for i, ftype := range fieldType {
  41. switch ftype {
  42. case reflect.TypeOf(uint64(0)):
  43. fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec)
  44. case reflect.TypeOf(""):
  45. fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec)
  46. default:
  47. fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype)
  48. }
  49. }
  50. return setup, flags, fields
  51. }
  52. func testNode(b byte) *enode.Node {
  53. r := &enr.Record{}
  54. r.SetSig(dummyIdentity{b}, []byte{42})
  55. n, _ := enode.New(dummyIdentity{b}, r)
  56. return n
  57. }
  58. func TestCallback(t *testing.T) {
  59. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  60. s, flags, _ := testSetup([]bool{false, false, false}, nil)
  61. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  62. set0 := make(chan struct{}, 1)
  63. set1 := make(chan struct{}, 1)
  64. set2 := make(chan struct{}, 1)
  65. ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} })
  66. ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} })
  67. ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} })
  68. ns.Start()
  69. ns.SetState(testNode(1), flags[0], Flags{}, 0)
  70. ns.SetState(testNode(1), flags[1], Flags{}, time.Second)
  71. ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second)
  72. for i := 0; i < 3; i++ {
  73. select {
  74. case <-set0:
  75. case <-set1:
  76. case <-set2:
  77. case <-time.After(time.Second):
  78. t.Fatalf("failed to invoke callback")
  79. }
  80. }
  81. }
  82. func TestPersistentFlags(t *testing.T) {
  83. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  84. s, flags, _ := testSetup([]bool{true, true, true, false}, nil)
  85. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  86. saveNode := make(chan *nodeInfo, 5)
  87. ns.saveNodeHook = func(node *nodeInfo) {
  88. saveNode <- node
  89. }
  90. ns.Start()
  91. ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved
  92. ns.SetState(testNode(2), flags[1], Flags{}, 0)
  93. ns.SetState(testNode(3), flags[2], Flags{}, 0)
  94. ns.SetState(testNode(4), flags[3], Flags{}, 0)
  95. ns.SetState(testNode(5), flags[0], Flags{}, 0)
  96. ns.Persist(testNode(5))
  97. select {
  98. case <-saveNode:
  99. case <-time.After(time.Second):
  100. t.Fatalf("Timeout")
  101. }
  102. ns.Stop()
  103. for i := 0; i < 2; i++ {
  104. select {
  105. case <-saveNode:
  106. case <-time.After(time.Second):
  107. t.Fatalf("Timeout")
  108. }
  109. }
  110. select {
  111. case <-saveNode:
  112. t.Fatalf("Unexpected saveNode")
  113. case <-time.After(time.Millisecond * 100):
  114. }
  115. }
  116. func TestSetField(t *testing.T) {
  117. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  118. s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")})
  119. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  120. saveNode := make(chan *nodeInfo, 1)
  121. ns.saveNodeHook = func(node *nodeInfo) {
  122. saveNode <- node
  123. }
  124. ns.Start()
  125. // Set field before setting state
  126. ns.SetField(testNode(1), fields[0], "hello world")
  127. field := ns.GetField(testNode(1), fields[0])
  128. if field == nil {
  129. t.Fatalf("Field should be set before setting states")
  130. }
  131. ns.SetField(testNode(1), fields[0], nil)
  132. field = ns.GetField(testNode(1), fields[0])
  133. if field != nil {
  134. t.Fatalf("Field should be unset")
  135. }
  136. // Set field after setting state
  137. ns.SetState(testNode(1), flags[0], Flags{}, 0)
  138. ns.SetField(testNode(1), fields[0], "hello world")
  139. field = ns.GetField(testNode(1), fields[0])
  140. if field == nil {
  141. t.Fatalf("Field should be set after setting states")
  142. }
  143. if err := ns.SetField(testNode(1), fields[0], 123); err == nil {
  144. t.Fatalf("Invalid field should be rejected")
  145. }
  146. // Dirty node should be written back
  147. ns.Stop()
  148. select {
  149. case <-saveNode:
  150. case <-time.After(time.Second):
  151. t.Fatalf("Timeout")
  152. }
  153. }
  154. func TestSetState(t *testing.T) {
  155. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  156. s, flags, _ := testSetup([]bool{false, false, false}, nil)
  157. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  158. type change struct{ old, new Flags }
  159. set := make(chan change, 1)
  160. ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) {
  161. set <- change{
  162. old: oldState,
  163. new: newState,
  164. }
  165. })
  166. ns.Start()
  167. check := func(expectOld, expectNew Flags, expectChange bool) {
  168. if expectChange {
  169. select {
  170. case c := <-set:
  171. if !c.old.Equals(expectOld) {
  172. t.Fatalf("Old state mismatch")
  173. }
  174. if !c.new.Equals(expectNew) {
  175. t.Fatalf("New state mismatch")
  176. }
  177. case <-time.After(time.Second):
  178. }
  179. return
  180. }
  181. select {
  182. case <-set:
  183. t.Fatalf("Unexpected change")
  184. case <-time.After(time.Millisecond * 100):
  185. return
  186. }
  187. }
  188. ns.SetState(testNode(1), flags[0], Flags{}, 0)
  189. check(Flags{}, flags[0], true)
  190. ns.SetState(testNode(1), flags[1], Flags{}, 0)
  191. check(flags[0], flags[0].Or(flags[1]), true)
  192. ns.SetState(testNode(1), flags[2], Flags{}, 0)
  193. check(Flags{}, Flags{}, false)
  194. ns.SetState(testNode(1), Flags{}, flags[0], 0)
  195. check(flags[0].Or(flags[1]), flags[1], true)
  196. ns.SetState(testNode(1), Flags{}, flags[1], 0)
  197. check(flags[1], Flags{}, true)
  198. ns.SetState(testNode(1), Flags{}, flags[2], 0)
  199. check(Flags{}, Flags{}, false)
  200. ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second)
  201. check(Flags{}, flags[0].Or(flags[1]), true)
  202. clock.Run(time.Second)
  203. check(flags[0].Or(flags[1]), Flags{}, true)
  204. }
  205. func uint64FieldEnc(field interface{}) ([]byte, error) {
  206. if u, ok := field.(uint64); ok {
  207. enc, err := rlp.EncodeToBytes(&u)
  208. return enc, err
  209. }
  210. return nil, errors.New("invalid field type")
  211. }
  212. func uint64FieldDec(enc []byte) (interface{}, error) {
  213. var u uint64
  214. err := rlp.DecodeBytes(enc, &u)
  215. return u, err
  216. }
  217. func stringFieldEnc(field interface{}) ([]byte, error) {
  218. if s, ok := field.(string); ok {
  219. return []byte(s), nil
  220. }
  221. return nil, errors.New("invalid field type")
  222. }
  223. func stringFieldDec(enc []byte) (interface{}, error) {
  224. return string(enc), nil
  225. }
  226. func TestPersistentFields(t *testing.T) {
  227. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  228. s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")})
  229. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  230. ns.Start()
  231. ns.SetState(testNode(1), flags[0], Flags{}, 0)
  232. ns.SetField(testNode(1), fields[0], uint64(100))
  233. ns.SetField(testNode(1), fields[1], "hello world")
  234. ns.Stop()
  235. ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  236. ns2.Start()
  237. field0 := ns2.GetField(testNode(1), fields[0])
  238. if !reflect.DeepEqual(field0, uint64(100)) {
  239. t.Fatalf("Field changed")
  240. }
  241. field1 := ns2.GetField(testNode(1), fields[1])
  242. if !reflect.DeepEqual(field1, "hello world") {
  243. t.Fatalf("Field changed")
  244. }
  245. s.Version++
  246. ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  247. ns3.Start()
  248. if ns3.GetField(testNode(1), fields[0]) != nil {
  249. t.Fatalf("Old field version should have been discarded")
  250. }
  251. }
  252. func TestFieldSub(t *testing.T) {
  253. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  254. s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))})
  255. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  256. var (
  257. lastState Flags
  258. lastOldValue, lastNewValue interface{}
  259. )
  260. ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
  261. lastState, lastOldValue, lastNewValue = state, oldValue, newValue
  262. })
  263. check := func(state Flags, oldValue, newValue interface{}) {
  264. if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue {
  265. t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue)
  266. }
  267. }
  268. ns.Start()
  269. ns.SetState(testNode(1), flags[0], Flags{}, 0)
  270. ns.SetField(testNode(1), fields[0], uint64(100))
  271. check(flags[0], nil, uint64(100))
  272. ns.Stop()
  273. check(s.OfflineFlag(), uint64(100), nil)
  274. ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  275. ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
  276. lastState, lastOldValue, lastNewValue = state, oldValue, newValue
  277. })
  278. ns2.Start()
  279. check(s.OfflineFlag(), nil, uint64(100))
  280. ns2.SetState(testNode(1), Flags{}, flags[0], 0)
  281. ns2.SetField(testNode(1), fields[0], nil)
  282. check(Flags{}, uint64(100), nil)
  283. ns2.Stop()
  284. }
  285. func TestDuplicatedFlags(t *testing.T) {
  286. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  287. s, flags, _ := testSetup([]bool{true}, nil)
  288. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  289. type change struct{ old, new Flags }
  290. set := make(chan change, 1)
  291. ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
  292. set <- change{oldState, newState}
  293. })
  294. ns.Start()
  295. defer ns.Stop()
  296. check := func(expectOld, expectNew Flags, expectChange bool) {
  297. if expectChange {
  298. select {
  299. case c := <-set:
  300. if !c.old.Equals(expectOld) {
  301. t.Fatalf("Old state mismatch")
  302. }
  303. if !c.new.Equals(expectNew) {
  304. t.Fatalf("New state mismatch")
  305. }
  306. case <-time.After(time.Second):
  307. }
  308. return
  309. }
  310. select {
  311. case <-set:
  312. t.Fatalf("Unexpected change")
  313. case <-time.After(time.Millisecond * 100):
  314. return
  315. }
  316. }
  317. ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
  318. check(Flags{}, flags[0], true)
  319. ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s
  320. check(Flags{}, flags[0], false)
  321. clock.Run(2 * time.Second)
  322. check(flags[0], Flags{}, true)
  323. }
  324. func TestCallbackOrder(t *testing.T) {
  325. mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
  326. s, flags, _ := testSetup([]bool{false, false, false, false}, nil)
  327. ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
  328. ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
  329. if newState.Equals(flags[0]) {
  330. ns.SetStateSub(n, flags[1], Flags{}, 0)
  331. ns.SetStateSub(n, flags[2], Flags{}, 0)
  332. }
  333. })
  334. ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) {
  335. if newState.Equals(flags[1]) {
  336. ns.SetStateSub(n, flags[3], Flags{}, 0)
  337. }
  338. })
  339. lastState := Flags{}
  340. ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) {
  341. if !oldState.Equals(lastState) {
  342. t.Fatalf("Wrong callback order")
  343. }
  344. lastState = newState
  345. })
  346. ns.Start()
  347. defer ns.Stop()
  348. ns.SetState(testNode(1), flags[0], Flags{}, 0)
  349. }