main.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. // Copyright 2022 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package main
  17. import (
  18. "bytes"
  19. "errors"
  20. "flag"
  21. "fmt"
  22. "go/types"
  23. "os"
  24. "golang.org/x/tools/go/packages"
  25. )
  26. const pathOfPackageRLP = "github.com/ethereum/go-ethereum/rlp"
  27. func main() {
  28. var (
  29. pkgdir = flag.String("dir", ".", "input package")
  30. output = flag.String("out", "-", "output file (default is stdout)")
  31. genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?")
  32. genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?")
  33. typename = flag.String("type", "", "type to generate methods for")
  34. )
  35. flag.Parse()
  36. cfg := Config{
  37. Dir: *pkgdir,
  38. Type: *typename,
  39. GenerateEncoder: *genEncoder,
  40. GenerateDecoder: *genDecoder,
  41. }
  42. code, err := cfg.process()
  43. if err != nil {
  44. fatal(err)
  45. }
  46. if *output == "-" {
  47. os.Stdout.Write(code)
  48. } else if err := os.WriteFile(*output, code, 0600); err != nil {
  49. fatal(err)
  50. }
  51. }
  52. func fatal(args ...interface{}) {
  53. fmt.Fprintln(os.Stderr, args...)
  54. os.Exit(1)
  55. }
  56. type Config struct {
  57. Dir string // input package directory
  58. Type string
  59. GenerateEncoder bool
  60. GenerateDecoder bool
  61. }
  62. // process generates the Go code.
  63. func (cfg *Config) process() (code []byte, err error) {
  64. // Load packages.
  65. pcfg := &packages.Config{
  66. Mode: packages.NeedName | packages.NeedTypes | packages.NeedImports | packages.NeedDeps,
  67. Dir: cfg.Dir,
  68. BuildFlags: []string{"-tags", "norlpgen"},
  69. }
  70. ps, err := packages.Load(pcfg, pathOfPackageRLP, ".")
  71. if err != nil {
  72. return nil, err
  73. }
  74. if len(ps) == 0 {
  75. return nil, fmt.Errorf("no Go package found in %s", cfg.Dir)
  76. }
  77. packages.PrintErrors(ps)
  78. // Find the packages that were loaded.
  79. var (
  80. pkg *types.Package
  81. packageRLP *types.Package
  82. )
  83. for _, p := range ps {
  84. if len(p.Errors) > 0 {
  85. return nil, fmt.Errorf("package %s has errors", p.PkgPath)
  86. }
  87. if p.PkgPath == pathOfPackageRLP {
  88. packageRLP = p.Types
  89. } else {
  90. pkg = p.Types
  91. }
  92. }
  93. bctx := newBuildContext(packageRLP)
  94. // Find the type and generate.
  95. typ, err := lookupStructType(pkg.Scope(), cfg.Type)
  96. if err != nil {
  97. return nil, fmt.Errorf("can't find %s in %s: %v", cfg.Type, pkg, err)
  98. }
  99. code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder)
  100. if err != nil {
  101. return nil, err
  102. }
  103. // Add build comments.
  104. // This is done here to avoid processing these lines with gofmt.
  105. var header bytes.Buffer
  106. fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n")
  107. fmt.Fprint(&header, "//go:build !norlpgen\n")
  108. fmt.Fprint(&header, "// +build !norlpgen\n\n")
  109. return append(header.Bytes(), code...), nil
  110. }
  111. func lookupStructType(scope *types.Scope, name string) (*types.Named, error) {
  112. typ, err := lookupType(scope, name)
  113. if err != nil {
  114. return nil, err
  115. }
  116. _, ok := typ.Underlying().(*types.Struct)
  117. if !ok {
  118. return nil, errors.New("not a struct type")
  119. }
  120. return typ, nil
  121. }
  122. func lookupType(scope *types.Scope, name string) (*types.Named, error) {
  123. obj := scope.Lookup(name)
  124. if obj == nil {
  125. return nil, errors.New("no such identifier")
  126. }
  127. typ, ok := obj.(*types.TypeName)
  128. if !ok {
  129. return nil, errors.New("not a type")
  130. }
  131. return typ.Type().(*types.Named), nil
  132. }