gen.go 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. // Copyright 2022 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package main
  17. import (
  18. "bytes"
  19. "fmt"
  20. "go/format"
  21. "go/types"
  22. "sort"
  23. "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
  24. )
  25. // buildContext keeps the data needed for make*Op.
  26. type buildContext struct {
  27. topType *types.Named // the type we're creating methods for
  28. encoderIface *types.Interface
  29. decoderIface *types.Interface
  30. rawValueType *types.Named
  31. typeToStructCache map[types.Type]*rlpstruct.Type
  32. }
  33. func newBuildContext(packageRLP *types.Package) *buildContext {
  34. enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying()
  35. dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying()
  36. rawv := packageRLP.Scope().Lookup("RawValue").Type()
  37. return &buildContext{
  38. typeToStructCache: make(map[types.Type]*rlpstruct.Type),
  39. encoderIface: enc.(*types.Interface),
  40. decoderIface: dec.(*types.Interface),
  41. rawValueType: rawv.(*types.Named),
  42. }
  43. }
  44. func (bctx *buildContext) isEncoder(typ types.Type) bool {
  45. return types.Implements(typ, bctx.encoderIface)
  46. }
  47. func (bctx *buildContext) isDecoder(typ types.Type) bool {
  48. return types.Implements(typ, bctx.decoderIface)
  49. }
  50. // typeToStructType converts typ to rlpstruct.Type.
  51. func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
  52. if prev := bctx.typeToStructCache[typ]; prev != nil {
  53. return prev // short-circuit for recursive types.
  54. }
  55. // Resolve named types to their underlying type, but keep the name.
  56. name := types.TypeString(typ, nil)
  57. for {
  58. utype := typ.Underlying()
  59. if utype == typ {
  60. break
  61. }
  62. typ = utype
  63. }
  64. // Create the type and store it in cache.
  65. t := &rlpstruct.Type{
  66. Name: name,
  67. Kind: typeReflectKind(typ),
  68. IsEncoder: bctx.isEncoder(typ),
  69. IsDecoder: bctx.isDecoder(typ),
  70. }
  71. bctx.typeToStructCache[typ] = t
  72. // Assign element type.
  73. switch typ.(type) {
  74. case *types.Array, *types.Slice, *types.Pointer:
  75. etype := typ.(interface{ Elem() types.Type }).Elem()
  76. t.Elem = bctx.typeToStructType(etype)
  77. }
  78. return t
  79. }
  // genContext carries per-file state while the gen* methods of op emit code.
// It tracks the packages the output file must import and hands out unique
// temporary variable names.
type genContext struct {
	inPackage   *types.Package      // package the generated code belongs to
	imports     map[string]struct{} // import paths required by the output
	tempCounter int                 // suffix of the next name returned by temp
}

// newGenContext returns an empty genContext for code emitted into inPackage.
func newGenContext(inPackage *types.Package) *genContext {
	ctx := &genContext{inPackage: inPackage}
	ctx.imports = map[string]struct{}{}
	return ctx
}

// temp returns a fresh temporary variable name: _tmp0, _tmp1, and so on.
func (ctx *genContext) temp() string {
	n := ctx.tempCounter
	ctx.tempCounter = n + 1
	return fmt.Sprintf("_tmp%d", n)
}

// resetTemp makes the next call to temp return _tmp0 again.
func (ctx *genContext) resetTemp() { ctx.tempCounter = 0 }

// addImport records path as an import of the generated file. The package
// the code is generated into is never imported.
func (ctx *genContext) addImport(path string) {
	// TODO: renaming?
	if path != ctx.inPackage.Path() {
		ctx.imports[path] = struct{}{}
	}
}

// importsList returns all recorded import paths in sorted order.
func (ctx *genContext) importsList() []string {
	paths := make([]string, 0, len(ctx.imports))
	for p := range ctx.imports {
		paths = append(paths, p)
	}
	sort.Strings(paths)
	return paths
}

// qualify is the types.Qualifier used when printing types into generated
// code. Types of the output package are printed unqualified; any other
// package is recorded as an import and referred to by its package name.
func (ctx *genContext) qualify(pkg *types.Package) string {
	if pkg.Path() == ctx.inPackage.Path() {
		return ""
	}
	ctx.addImport(pkg.Path())
	// TODO: renaming?
	return pkg.Name()
}
  127. type op interface {
  128. // genWrite creates the encoder. The generated code should write v,
  129. // which is any Go expression, to the rlp.EncoderBuffer 'w'.
  130. genWrite(ctx *genContext, v string) string
  131. // genDecode creates the decoder. The generated code should read
  132. // a value from the rlp.Stream 'dec' and store it to dst.
  133. genDecode(ctx *genContext) (string, string)
  134. }
  // basicOp handles types that map onto one rlp.EncoderBuffer write method
// and one rlp.Stream read method: the basic types bool, uint* and string,
// as well as byte slices and rlp.RawValue.
type basicOp struct {
	typ           types.Type // the Go type being encoded/decoded
	writeMethod   string     // called to write the value
	writeArgType  types.Type // parameter type of writeMethod
	decMethod     string     // rlp.Stream method that reads the value
	decResultType types.Type // return type of decMethod
	decUseBitSize bool       // if true, result bit size is appended to decMethod
}
  144. func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) {
  145. op := basicOp{typ: typ}
  146. kind := typ.Kind()
  147. switch {
  148. case kind == types.Bool:
  149. op.writeMethod = "WriteBool"
  150. op.writeArgType = types.Typ[types.Bool]
  151. op.decMethod = "Bool"
  152. op.decResultType = types.Typ[types.Bool]
  153. case kind >= types.Uint8 && kind <= types.Uint64:
  154. op.writeMethod = "WriteUint64"
  155. op.writeArgType = types.Typ[types.Uint64]
  156. op.decMethod = "Uint"
  157. op.decResultType = typ
  158. op.decUseBitSize = true
  159. case kind == types.String:
  160. op.writeMethod = "WriteString"
  161. op.writeArgType = types.Typ[types.String]
  162. op.decMethod = "String"
  163. op.decResultType = types.Typ[types.String]
  164. default:
  165. return nil, fmt.Errorf("unhandled basic type: %v", typ)
  166. }
  167. return op, nil
  168. }
  169. func (*buildContext) makeByteSliceOp(typ *types.Slice) op {
  170. if !isByte(typ.Elem()) {
  171. panic("non-byte slice type in makeByteSliceOp")
  172. }
  173. bslice := types.NewSlice(types.Typ[types.Uint8])
  174. return basicOp{
  175. typ: typ,
  176. writeMethod: "WriteBytes",
  177. writeArgType: bslice,
  178. decMethod: "Bytes",
  179. decResultType: bslice,
  180. }
  181. }
  182. func (bctx *buildContext) makeRawValueOp() op {
  183. bslice := types.NewSlice(types.Typ[types.Uint8])
  184. return basicOp{
  185. typ: bctx.rawValueType,
  186. writeMethod: "Write",
  187. writeArgType: bslice,
  188. decMethod: "Raw",
  189. decResultType: bslice,
  190. }
  191. }
  192. func (op basicOp) writeNeedsConversion() bool {
  193. return !types.AssignableTo(op.typ, op.writeArgType)
  194. }
  195. func (op basicOp) decodeNeedsConversion() bool {
  196. return !types.AssignableTo(op.decResultType, op.typ)
  197. }
  198. func (op basicOp) genWrite(ctx *genContext, v string) string {
  199. if op.writeNeedsConversion() {
  200. v = fmt.Sprintf("%s(%s)", op.writeArgType, v)
  201. }
  202. return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v)
  203. }
  204. func (op basicOp) genDecode(ctx *genContext) (string, string) {
  205. var (
  206. resultV = ctx.temp()
  207. result = resultV
  208. method = op.decMethod
  209. )
  210. if op.decUseBitSize {
  211. // Note: For now, this only works for platform-independent integer
  212. // sizes. makeBasicOp forbids the platform-dependent types.
  213. var sizes types.StdSizes
  214. method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8)
  215. }
  216. // Call the decoder method.
  217. var b bytes.Buffer
  218. fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method)
  219. fmt.Fprintf(&b, "if err != nil { return err }\n")
  220. if op.decodeNeedsConversion() {
  221. conv := ctx.temp()
  222. fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV)
  223. result = conv
  224. }
  225. return result, b.String()
  226. }
  // byteArrayOp handles fixed-size byte arrays ([...]byte). They are written
// with WriteBytes and read in place with Stream.ReadBytes.
type byteArrayOp struct {
	typ  types.Type // the array type itself
	name types.Type // name != typ for named byte array types (e.g. common.Address)
}
  232. func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp {
  233. nt := types.Type(name)
  234. if name == nil {
  235. nt = typ
  236. }
  237. return byteArrayOp{typ, nt}
  238. }
  239. func (op byteArrayOp) genWrite(ctx *genContext, v string) string {
  240. return fmt.Sprintf("w.WriteBytes(%s[:])\n", v)
  241. }
  242. func (op byteArrayOp) genDecode(ctx *genContext) (string, string) {
  243. var resultV = ctx.temp()
  244. var b bytes.Buffer
  245. fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify))
  246. fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV)
  247. return resultV, b.String()
  248. }
  // bigIntOp handles big.Int and *big.Int.
// This exists because big.Int has its own encoding and decoding operations
// on rlp.EncoderBuffer and rlp.Stream. The decode method returns *big.Int,
// so for non-pointer values the result needs to be dereferenced.
type bigIntOp struct {
	pointer bool // true for *big.Int, false for big.Int
}
  255. func (op bigIntOp) genWrite(ctx *genContext, v string) string {
  256. var b bytes.Buffer
  257. fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v)
  258. fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n")
  259. fmt.Fprintf(&b, "}\n")
  260. dst := v
  261. if !op.pointer {
  262. dst = "&" + v
  263. }
  264. fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst)
  265. // Wrap with nil check.
  266. if op.pointer {
  267. code := b.String()
  268. b.Reset()
  269. fmt.Fprintf(&b, "if %s == nil {\n", v)
  270. fmt.Fprintf(&b, " w.Write(rlp.EmptyString)")
  271. fmt.Fprintf(&b, "} else {\n")
  272. fmt.Fprint(&b, code)
  273. fmt.Fprintf(&b, "}\n")
  274. }
  275. return b.String()
  276. }
  277. func (op bigIntOp) genDecode(ctx *genContext) (string, string) {
  278. var resultV = ctx.temp()
  279. var b bytes.Buffer
  280. fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV)
  281. fmt.Fprintf(&b, "if err != nil { return err }\n")
  282. result := resultV
  283. if !op.pointer {
  284. result = "(*" + resultV + ")"
  285. }
  286. return result, b.String()
  287. }
  288. // encoderDecoderOp handles rlp.Encoder and rlp.Decoder.
  289. // In order to be used with this, the type must implement both interfaces.
  290. // This restriction may be lifted in the future by creating separate ops for
  291. // encoding and decoding.
  292. type encoderDecoderOp struct {
  293. typ types.Type
  294. }
  295. func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string {
  296. return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v)
  297. }
  298. func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) {
  299. // DecodeRLP must have pointer receiver, and this is verified in makeOp.
  300. etyp := op.typ.(*types.Pointer).Elem()
  301. var resultV = ctx.temp()
  302. var b bytes.Buffer
  303. fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify))
  304. fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV)
  305. return resultV, b.String()
  306. }
  307. // ptrOp handles pointer types.
  308. type ptrOp struct {
  309. elemTyp types.Type
  310. elem op
  311. nilOK bool
  312. nilValue rlpstruct.NilKind
  313. }
  314. func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) {
  315. elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{})
  316. if err != nil {
  317. return nil, err
  318. }
  319. op := ptrOp{elemTyp: elemTyp, elem: elemOp}
  320. // Determine nil value.
  321. if tags.NilOK {
  322. op.nilOK = true
  323. op.nilValue = tags.NilKind
  324. } else {
  325. styp := bctx.typeToStructType(elemTyp)
  326. op.nilValue = styp.DefaultNilValue()
  327. }
  328. return op, nil
  329. }
  330. func (op ptrOp) genWrite(ctx *genContext, v string) string {
  331. // Note: in writer functions, accesses to v are read-only, i.e. v is any Go
  332. // expression. To make all accesses work through the pointer, we substitute
  333. // v with (*v). This is required for most accesses including `v`, `call(v)`,
  334. // and `v[index]` on slices.
  335. //
  336. // For `v.field` and `v[:]` on arrays, the dereference operation is not required.
  337. var vv string
  338. _, isStruct := op.elem.(structOp)
  339. _, isByteArray := op.elem.(byteArrayOp)
  340. if isStruct || isByteArray {
  341. vv = v
  342. } else {
  343. vv = fmt.Sprintf("(*%s)", v)
  344. }
  345. var b bytes.Buffer
  346. fmt.Fprintf(&b, "if %s == nil {\n", v)
  347. fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue)
  348. fmt.Fprintf(&b, "} else {\n")
  349. fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv))
  350. fmt.Fprintf(&b, "}\n")
  351. return b.String()
  352. }
  353. func (op ptrOp) genDecode(ctx *genContext) (string, string) {
  354. result, code := op.elem.genDecode(ctx)
  355. if !op.nilOK {
  356. // If nil pointers are not allowed, we can just decode the element.
  357. return "&" + result, code
  358. }
  359. // nil is allowed, so check the kind and size first.
  360. // If size is zero and kind matches the nilKind of the type,
  361. // the value decodes as a nil pointer.
  362. var (
  363. resultV = ctx.temp()
  364. kindV = ctx.temp()
  365. sizeV = ctx.temp()
  366. wantKind string
  367. )
  368. if op.nilValue == rlpstruct.NilKindList {
  369. wantKind = "rlp.List"
  370. } else {
  371. wantKind = "rlp.String"
  372. }
  373. var b bytes.Buffer
  374. fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify))
  375. fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV)
  376. fmt.Fprintf(&b, " return err\n")
  377. fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind)
  378. fmt.Fprint(&b, code)
  379. fmt.Fprintf(&b, " %s = &%s\n", resultV, result)
  380. fmt.Fprintf(&b, "}\n")
  381. return resultV, b.String()
  382. }
  383. // structOp handles struct types.
  384. type structOp struct {
  385. named *types.Named
  386. typ *types.Struct
  387. fields []*structField
  388. optionalFields []*structField
  389. }
  390. type structField struct {
  391. name string
  392. typ types.Type
  393. elem op
  394. }
  395. func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) {
  396. // Convert fields to []rlpstruct.Field.
  397. var allStructFields []rlpstruct.Field
  398. for i := 0; i < typ.NumFields(); i++ {
  399. f := typ.Field(i)
  400. allStructFields = append(allStructFields, rlpstruct.Field{
  401. Name: f.Name(),
  402. Exported: f.Exported(),
  403. Index: i,
  404. Tag: typ.Tag(i),
  405. Type: *bctx.typeToStructType(f.Type()),
  406. })
  407. }
  408. // Filter/validate fields.
  409. fields, tags, err := rlpstruct.ProcessFields(allStructFields)
  410. if err != nil {
  411. return nil, err
  412. }
  413. // Create field ops.
  414. var op = structOp{named: named, typ: typ}
  415. for i, field := range fields {
  416. // Advanced struct tags are not supported yet.
  417. tag := tags[i]
  418. if err := checkUnsupportedTags(field.Name, tag); err != nil {
  419. return nil, err
  420. }
  421. typ := typ.Field(field.Index).Type()
  422. elem, err := bctx.makeOp(nil, typ, tags[i])
  423. if err != nil {
  424. return nil, fmt.Errorf("field %s: %v", field.Name, err)
  425. }
  426. f := &structField{name: field.Name, typ: typ, elem: elem}
  427. if tag.Optional {
  428. op.optionalFields = append(op.optionalFields, f)
  429. } else {
  430. op.fields = append(op.fields, f)
  431. }
  432. }
  433. return op, nil
  434. }
  435. func checkUnsupportedTags(field string, tag rlpstruct.Tags) error {
  436. if tag.Tail {
  437. return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field)
  438. }
  439. return nil
  440. }
  441. func (op structOp) genWrite(ctx *genContext, v string) string {
  442. var b bytes.Buffer
  443. var listMarker = ctx.temp()
  444. fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
  445. for _, field := range op.fields {
  446. selector := v + "." + field.name
  447. fmt.Fprint(&b, field.elem.genWrite(ctx, selector))
  448. }
  449. op.writeOptionalFields(&b, ctx, v)
  450. fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
  451. return b.String()
  452. }
  453. func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) {
  454. if len(op.optionalFields) == 0 {
  455. return
  456. }
  457. // First check zero-ness of all optional fields.
  458. var zeroV = make([]string, len(op.optionalFields))
  459. for i, field := range op.optionalFields {
  460. selector := v + "." + field.name
  461. zeroV[i] = ctx.temp()
  462. fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify))
  463. }
  464. // Now write the fields.
  465. for i, field := range op.optionalFields {
  466. selector := v + "." + field.name
  467. cond := ""
  468. for j := i; j < len(op.optionalFields); j++ {
  469. if j > i {
  470. cond += " || "
  471. }
  472. cond += zeroV[j]
  473. }
  474. fmt.Fprintf(b, "if %s {\n", cond)
  475. fmt.Fprint(b, field.elem.genWrite(ctx, selector))
  476. fmt.Fprintf(b, "}\n")
  477. }
  478. }
  479. func (op structOp) genDecode(ctx *genContext) (string, string) {
  480. // Get the string representation of the type.
  481. // Here, named types are handled separately because the output
  482. // would contain a copy of the struct definition otherwise.
  483. var typeName string
  484. if op.named != nil {
  485. typeName = types.TypeString(op.named, ctx.qualify)
  486. } else {
  487. typeName = types.TypeString(op.typ, ctx.qualify)
  488. }
  489. // Create struct object.
  490. var resultV = ctx.temp()
  491. var b bytes.Buffer
  492. fmt.Fprintf(&b, "var %s %s\n", resultV, typeName)
  493. // Decode fields.
  494. fmt.Fprintf(&b, "{\n")
  495. fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
  496. for _, field := range op.fields {
  497. result, code := field.elem.genDecode(ctx)
  498. fmt.Fprintf(&b, "// %s:\n", field.name)
  499. fmt.Fprint(&b, code)
  500. fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result)
  501. }
  502. op.decodeOptionalFields(&b, ctx, resultV)
  503. fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
  504. fmt.Fprintf(&b, "}\n")
  505. return resultV, b.String()
  506. }
  507. func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) {
  508. var suffix bytes.Buffer
  509. for _, field := range op.optionalFields {
  510. result, code := field.elem.genDecode(ctx)
  511. fmt.Fprintf(b, "// %s:\n", field.name)
  512. fmt.Fprintf(b, "if dec.MoreDataInList() {\n")
  513. fmt.Fprint(b, code)
  514. fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result)
  515. fmt.Fprintf(&suffix, "}\n")
  516. }
  517. suffix.WriteTo(b)
  518. }
  519. // sliceOp handles slice types.
  520. type sliceOp struct {
  521. typ *types.Slice
  522. elemOp op
  523. }
  524. func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) {
  525. elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{})
  526. if err != nil {
  527. return nil, err
  528. }
  529. return sliceOp{typ: typ, elemOp: elemOp}, nil
  530. }
  531. func (op sliceOp) genWrite(ctx *genContext, v string) string {
  532. var (
  533. listMarker = ctx.temp() // holds return value of w.List()
  534. iterElemV = ctx.temp() // iteration variable
  535. elemCode = op.elemOp.genWrite(ctx, iterElemV)
  536. )
  537. var b bytes.Buffer
  538. fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
  539. fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v)
  540. fmt.Fprint(&b, elemCode)
  541. fmt.Fprintf(&b, "}\n")
  542. fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
  543. return b.String()
  544. }
  545. func (op sliceOp) genDecode(ctx *genContext) (string, string) {
  546. var sliceV = ctx.temp() // holds the output slice
  547. elemResult, elemCode := op.elemOp.genDecode(ctx)
  548. var b bytes.Buffer
  549. fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify))
  550. fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
  551. fmt.Fprintf(&b, "for dec.MoreDataInList() {\n")
  552. fmt.Fprintf(&b, " %s", elemCode)
  553. fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult)
  554. fmt.Fprintf(&b, "}\n")
  555. fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
  556. return sliceV, b.String()
  557. }
  558. func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) {
  559. switch typ := typ.(type) {
  560. case *types.Named:
  561. if isBigInt(typ) {
  562. return bigIntOp{}, nil
  563. }
  564. if typ == bctx.rawValueType {
  565. return bctx.makeRawValueOp(), nil
  566. }
  567. if bctx.isDecoder(typ) {
  568. return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ)
  569. }
  570. // TODO: same check for encoder?
  571. return bctx.makeOp(typ, typ.Underlying(), tags)
  572. case *types.Pointer:
  573. if isBigInt(typ.Elem()) {
  574. return bigIntOp{pointer: true}, nil
  575. }
  576. // Encoder/Decoder interfaces.
  577. if bctx.isEncoder(typ) {
  578. if bctx.isDecoder(typ) {
  579. return encoderDecoderOp{typ}, nil
  580. }
  581. return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ)
  582. }
  583. if bctx.isDecoder(typ) {
  584. return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ)
  585. }
  586. // Default pointer handling.
  587. return bctx.makePtrOp(typ.Elem(), tags)
  588. case *types.Basic:
  589. return bctx.makeBasicOp(typ)
  590. case *types.Struct:
  591. return bctx.makeStructOp(name, typ)
  592. case *types.Slice:
  593. etyp := typ.Elem()
  594. if isByte(etyp) && !bctx.isEncoder(etyp) {
  595. return bctx.makeByteSliceOp(typ), nil
  596. }
  597. return bctx.makeSliceOp(typ)
  598. case *types.Array:
  599. etyp := typ.Elem()
  600. if isByte(etyp) && !bctx.isEncoder(etyp) {
  601. return bctx.makeByteArrayOp(name, typ), nil
  602. }
  603. return nil, fmt.Errorf("unhandled array type: %v", typ)
  604. default:
  605. return nil, fmt.Errorf("unhandled type: %v", typ)
  606. }
  607. }
  608. // generateDecoder generates the DecodeRLP method on 'typ'.
  609. func generateDecoder(ctx *genContext, typ string, op op) []byte {
  610. ctx.resetTemp()
  611. ctx.addImport(pathOfPackageRLP)
  612. result, code := op.genDecode(ctx)
  613. var b bytes.Buffer
  614. fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ)
  615. fmt.Fprint(&b, code)
  616. fmt.Fprintf(&b, " *obj = %s\n", result)
  617. fmt.Fprintf(&b, " return nil\n")
  618. fmt.Fprintf(&b, "}\n")
  619. return b.Bytes()
  620. }
  621. // generateEncoder generates the EncodeRLP method on 'typ'.
  622. func generateEncoder(ctx *genContext, typ string, op op) []byte {
  623. ctx.resetTemp()
  624. ctx.addImport("io")
  625. ctx.addImport(pathOfPackageRLP)
  626. var b bytes.Buffer
  627. fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
  628. fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n")
  629. fmt.Fprint(&b, op.genWrite(ctx, "obj"))
  630. fmt.Fprintf(&b, " return w.Flush()\n")
  631. fmt.Fprintf(&b, "}\n")
  632. return b.Bytes()
  633. }
  634. func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) {
  635. bctx.topType = typ
  636. pkg := typ.Obj().Pkg()
  637. op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{})
  638. if err != nil {
  639. return nil, err
  640. }
  641. var (
  642. ctx = newGenContext(pkg)
  643. encSource []byte
  644. decSource []byte
  645. )
  646. if encoder {
  647. encSource = generateEncoder(ctx, typ.Obj().Name(), op)
  648. }
  649. if decoder {
  650. decSource = generateDecoder(ctx, typ.Obj().Name(), op)
  651. }
  652. var b bytes.Buffer
  653. fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
  654. for _, imp := range ctx.importsList() {
  655. fmt.Fprintf(&b, "import %q\n", imp)
  656. }
  657. if encoder {
  658. fmt.Fprintln(&b)
  659. b.Write(encSource)
  660. }
  661. if decoder {
  662. fmt.Fprintln(&b)
  663. b.Write(decSource)
  664. }
  665. source := b.Bytes()
  666. // fmt.Println(string(source))
  667. return format.Source(source)
  668. }