gen_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. // Copyright 2022 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package main
  17. import (
  18. "bytes"
  19. "fmt"
  20. "go/ast"
  21. "go/importer"
  22. "go/parser"
  23. "go/token"
  24. "go/types"
  25. "os"
  26. "path/filepath"
  27. "strings"
  28. "testing"
  29. )
  // Shared fixtures. The RLP package is imported a single time (by init) and
// the result is reused by every test case instead of being re-imported.
var (
	// testFset records source positions for everything parsed in tests,
	// including packages pulled in through testImporter.
	testFset = token.NewFileSet()

	// testImporter imports packages by type-checking them from source
	// ("source" compiler), using testFset for positions.
	testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom)

	// testPackageRLP holds the imported RLP package; it is assigned in init
	// and handed to newBuildContext by loadTestSource.
	testPackageRLP *types.Package
)
  36. func init() {
  37. cwd, err := os.Getwd()
  38. if err != nil {
  39. panic(err)
  40. }
  41. testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0)
  42. if err != nil {
  43. panic(fmt.Errorf("can't load package RLP: %v", err))
  44. }
  45. }
  // tests names the golden-file cases run by TestOutput. Each name maps to an
// input testdata/<name>.in.txt and an expected output testdata/<name>.out.txt.
var tests = []string{
	"uints",
	"nil",
	"rawvalue",
	"optional",
	"bigint",
}
  47. func TestOutput(t *testing.T) {
  48. for _, test := range tests {
  49. test := test
  50. t.Run(test, func(t *testing.T) {
  51. inputFile := filepath.Join("testdata", test+".in.txt")
  52. outputFile := filepath.Join("testdata", test+".out.txt")
  53. bctx, typ, err := loadTestSource(inputFile, "Test")
  54. if err != nil {
  55. t.Fatal("error loading test source:", err)
  56. }
  57. output, err := bctx.generate(typ, true, true)
  58. if err != nil {
  59. t.Fatal("error in generate:", err)
  60. }
  61. // Set this environment variable to regenerate the test outputs.
  62. if os.Getenv("WRITE_TEST_FILES") != "" {
  63. os.WriteFile(outputFile, output, 0644)
  64. }
  65. // Check if output matches.
  66. wantOutput, err := os.ReadFile(outputFile)
  67. if err != nil {
  68. t.Fatal("error loading expected test output:", err)
  69. }
  70. output_string := strings.ReplaceAll(string(output), "\r\n", "\n")
  71. output = []byte(output_string)
  72. wantOutput_string := strings.ReplaceAll(string(wantOutput), "\r\n", "\n")
  73. wantOutput = []byte(wantOutput_string)
  74. if !bytes.Equal(output, wantOutput) {
  75. t.Fatal("output mismatch:\n", string(output))
  76. }
  77. })
  78. }
  79. }
  80. func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) {
  81. // Load the test input.
  82. content, err := os.ReadFile(file)
  83. if err != nil {
  84. return nil, nil, err
  85. }
  86. f, err := parser.ParseFile(testFset, file, content, 0)
  87. if err != nil {
  88. return nil, nil, err
  89. }
  90. conf := types.Config{Importer: testImporter}
  91. pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil)
  92. if err != nil {
  93. return nil, nil, err
  94. }
  95. // Find the test struct.
  96. bctx := newBuildContext(testPackageRLP)
  97. typ, err := lookupStructType(pkg.Scope(), typeName)
  98. if err != nil {
  99. return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err)
  100. }
  101. return bctx, typ, nil
  102. }