Эх сурвалжийг харах

common: add database/sql support for Hash and Address (#15541)

Vincent Serpoul 7 жил өмнө
parent
commit
2909f6d7a2
2 өөрчлөгдсөн 219 нэмэгдсэн , 2 устгасан
  1. 40 1
      common/types.go
  2. 179 1
      common/types_test.go

+ 40 - 1
common/types.go

@@ -17,6 +17,7 @@
 package common
 
 import (
+	"database/sql/driver"
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
@@ -31,7 +32,9 @@ import (
 
 // Lengths of hashes and addresses in bytes.
 const (
-	HashLength    = 32
+	// HashLength is the expected length of the hash
+	HashLength = 32
+	// AddressLength is the expected length of the adddress
 	AddressLength = 20
 )
 
@@ -120,6 +123,24 @@ func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value {
 	return reflect.ValueOf(h)
 }
 
+// Scan implements Scanner for database/sql.
+func (h *Hash) Scan(src interface{}) error {
+	srcB, ok := src.([]byte)
+	if !ok {
+		return fmt.Errorf("can't scan %T into Hash", src)
+	}
+	if len(srcB) != HashLength {
+		return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), HashLength)
+	}
+	copy(h[:], srcB)
+	return nil
+}
+
+// Value implements valuer for database/sql.
+func (h Hash) Value() (driver.Value, error) {
+	return h[:], nil
+}
+
 // UnprefixedHash allows marshaling a Hash without 0x prefix.
 type UnprefixedHash Hash
 
@@ -229,6 +250,24 @@ func (a *Address) UnmarshalJSON(input []byte) error {
 	return hexutil.UnmarshalFixedJSON(addressT, input, a[:])
 }
 
+// Scan implements Scanner for database/sql.
+func (a *Address) Scan(src interface{}) error {
+	srcB, ok := src.([]byte)
+	if !ok {
+		return fmt.Errorf("can't scan %T into Address", src)
+	}
+	if len(srcB) != AddressLength {
+		return fmt.Errorf("can't scan []byte of len %d into Address, want %d", len(srcB), AddressLength)
+	}
+	copy(a[:], srcB)
+	return nil
+}
+
+// Value implements valuer for database/sql.
+func (a Address) Value() (driver.Value, error) {
+	return a[:], nil
+}
+
 // UnprefixedAddress allows marshaling an Address without 0x prefix.
 type UnprefixedAddress Address
 

+ 179 - 1
common/types_test.go

@@ -17,9 +17,10 @@
 package common
 
 import (
+	"database/sql/driver"
 	"encoding/json"
-
 	"math/big"
+	"reflect"
 	"strings"
 	"testing"
 )
@@ -193,3 +194,180 @@ func TestMixedcaseAccount_Address(t *testing.T) {
 	}
 
 }
+
+func TestHash_Scan(t *testing.T) {
+	type args struct {
+		src interface{}
+	}
+	tests := []struct {
+		name    string
+		args    args
+		wantErr bool
+	}{
+		{
+			name: "working scan",
+			args: args{src: []byte{
+				0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+				0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+				0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+				0x10, 0x00,
+			}},
+			wantErr: false,
+		},
+		{
+			name:    "non working scan",
+			args:    args{src: int64(1234567890)},
+			wantErr: true,
+		},
+		{
+			name: "invalid length scan",
+			args: args{src: []byte{
+				0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+				0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+				0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+			}},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			h := &Hash{}
+			if err := h.Scan(tt.args.src); (err != nil) != tt.wantErr {
+				t.Errorf("Hash.Scan() error = %v, wantErr %v", err, tt.wantErr)
+			}
+
+			if !tt.wantErr {
+				for i := range h {
+					if h[i] != tt.args.src.([]byte)[i] {
+						t.Errorf(
+							"Hash.Scan() didn't scan the %d src correctly (have %X, want %X)",
+							i, h[i], tt.args.src.([]byte)[i],
+						)
+					}
+				}
+			}
+		})
+	}
+}
+
+func TestHash_Value(t *testing.T) {
+	b := []byte{
+		0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+		0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+		0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+		0x10, 0x00,
+	}
+	var usedH Hash
+	usedH.SetBytes(b)
+	tests := []struct {
+		name    string
+		h       Hash
+		want    driver.Value
+		wantErr bool
+	}{
+		{
+			name:    "Working value",
+			h:       usedH,
+			want:    b,
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.h.Value()
+			if (err != nil) != tt.wantErr {
+				t.Errorf("Hash.Value() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("Hash.Value() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestAddress_Scan(t *testing.T) {
+	type args struct {
+		src interface{}
+	}
+	tests := []struct {
+		name    string
+		args    args
+		wantErr bool
+	}{
+		{
+			name: "working scan",
+			args: args{src: []byte{
+				0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+				0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+			}},
+			wantErr: false,
+		},
+		{
+			name:    "non working scan",
+			args:    args{src: int64(1234567890)},
+			wantErr: true,
+		},
+		{
+			name: "invalid length scan",
+			args: args{src: []byte{
+				0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+				0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a,
+			}},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			a := &Address{}
+			if err := a.Scan(tt.args.src); (err != nil) != tt.wantErr {
+				t.Errorf("Address.Scan() error = %v, wantErr %v", err, tt.wantErr)
+			}
+
+			if !tt.wantErr {
+				for i := range a {
+					if a[i] != tt.args.src.([]byte)[i] {
+						t.Errorf(
+							"Address.Scan() didn't scan the %d src correctly (have %X, want %X)",
+							i, a[i], tt.args.src.([]byte)[i],
+						)
+					}
+				}
+			}
+		})
+	}
+}
+
+func TestAddress_Value(t *testing.T) {
+	b := []byte{
+		0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
+		0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
+	}
+	var usedA Address
+	usedA.SetBytes(b)
+	tests := []struct {
+		name    string
+		a       Address
+		want    driver.Value
+		wantErr bool
+	}{
+		{
+			name:    "Working value",
+			a:       usedA,
+			want:    b,
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.a.Value()
+			if (err != nil) != tt.wantErr {
+				t.Errorf("Address.Value() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("Address.Value() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}