mirror

Mirror free and open-source projects you like with minimal effort
git clone git://git.server.ky/slackcoder/mirror
Log | Files | Refs | README

decode.go (17015B)


      1 package toml
      2 
      3 import (
      4 	"bytes"
      5 	"encoding"
      6 	"encoding/json"
      7 	"fmt"
      8 	"io"
      9 	"io/fs"
     10 	"math"
     11 	"os"
     12 	"reflect"
     13 	"strconv"
     14 	"strings"
     15 	"time"
     16 )
     17 
     18 // Unmarshaler is the interface implemented by objects that can unmarshal a
     19 // TOML description of themselves.
     20 type Unmarshaler interface {
     21 	UnmarshalTOML(any) error
     22 }
     23 
     24 // Unmarshal decodes the contents of data in TOML format into a pointer v.
     25 //
     26 // See [Decoder] for a description of the decoding process.
     27 func Unmarshal(data []byte, v any) error {
     28 	_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
     29 	return err
     30 }
     31 
     32 // Decode the TOML data in to the pointer v.
     33 //
     34 // See [Decoder] for a description of the decoding process.
     35 func Decode(data string, v any) (MetaData, error) {
     36 	return NewDecoder(strings.NewReader(data)).Decode(v)
     37 }
     38 
     39 // DecodeFile reads the contents of a file and decodes it with [Decode].
     40 func DecodeFile(path string, v any) (MetaData, error) {
     41 	fp, err := os.Open(path)
     42 	if err != nil {
     43 		return MetaData{}, err
     44 	}
     45 	defer fp.Close()
     46 	return NewDecoder(fp).Decode(v)
     47 }
     48 
     49 // DecodeFS reads the contents of a file from [fs.FS] and decodes it with
     50 // [Decode].
     51 func DecodeFS(fsys fs.FS, path string, v any) (MetaData, error) {
     52 	fp, err := fsys.Open(path)
     53 	if err != nil {
     54 		return MetaData{}, err
     55 	}
     56 	defer fp.Close()
     57 	return NewDecoder(fp).Decode(v)
     58 }
     59 
// Primitive is a TOML value that hasn't been decoded into a Go value.
//
// This type can be used for any value, which will cause decoding to be delayed.
// You can use [MetaData.PrimitiveDecode] to "manually" decode these values.
//
// NOTE: The underlying representation of a `Primitive` value is subject to
// change. Do not rely on it.
//
// NOTE: Primitive values are still parsed, so using them will only avoid the
// overhead of reflection. They can be useful when you don't know the exact type
// of TOML data until runtime.
type Primitive struct {
	// undecoded is the already-parsed TOML value that was handed to the
	// decoder when it encountered this Primitive.
	undecoded any

	// context is a private copy of the key path that was current when this
	// Primitive was filled in; PrimitiveDecode reinstates it so that error
	// messages and decoded-key tracking refer to the right key.
	context Key
}
     75 
     76 // The significand precision for float32 and float64 is 24 and 53 bits; this is
     77 // the range a natural number can be stored in a float without loss of data.
     78 const (
     79 	maxSafeFloat32Int = 16777215                // 2^24-1
     80 	maxSafeFloat64Int = int64(9007199254740991) // 2^53-1
     81 )
     82 
     83 // Decoder decodes TOML data.
     84 //
     85 // TOML tables correspond to Go structs or maps; they can be used
     86 // interchangeably, but structs offer better type safety.
     87 //
     88 // TOML table arrays correspond to either a slice of structs or a slice of maps.
     89 //
     90 // TOML datetimes correspond to [time.Time]. Local datetimes are parsed in the
     91 // local timezone.
     92 //
     93 // [time.Duration] types are treated as nanoseconds if the TOML value is an
     94 // integer, or they're parsed with time.ParseDuration() if they're strings.
     95 //
     96 // All other TOML types (float, string, int, bool and array) correspond to the
     97 // obvious Go types.
     98 //
     99 // An exception to the above rules is if a type implements the TextUnmarshaler
    100 // interface, in which case any primitive TOML value (floats, strings, integers,
    101 // booleans, datetimes) will be converted to a []byte and given to the value's
    102 // UnmarshalText method. See the Unmarshaler example for a demonstration with
    103 // email addresses.
    104 //
    105 // # Key mapping
    106 //
    107 // TOML keys can map to either keys in a Go map or field names in a Go struct.
    108 // The special `toml` struct tag can be used to map TOML keys to struct fields
    109 // that don't match the key name exactly (see the example). A case insensitive
    110 // match to struct names will be tried if an exact match can't be found.
    111 //
    112 // The mapping between TOML values and Go values is loose. That is, there may
    113 // exist TOML values that cannot be placed into your representation, and there
    114 // may be parts of your representation that do not correspond to TOML values.
    115 // This loose mapping can be made stricter by using the IsDefined and/or
    116 // Undecoded methods on the MetaData returned.
    117 //
    118 // This decoder does not handle cyclic types. Decode will not terminate if a
    119 // cyclic type is passed.
    120 type Decoder struct {
    121 	r io.Reader
    122 }
    123 
    124 // NewDecoder creates a new Decoder.
    125 func NewDecoder(r io.Reader) *Decoder {
    126 	return &Decoder{r: r}
    127 }
    128 
    129 var (
    130 	unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
    131 	unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
    132 	primitiveType = reflect.TypeOf((*Primitive)(nil)).Elem()
    133 )
    134 
    135 // Decode TOML data in to the pointer `v`.
    136 func (dec *Decoder) Decode(v any) (MetaData, error) {
    137 	rv := reflect.ValueOf(v)
    138 	if rv.Kind() != reflect.Ptr {
    139 		s := "%q"
    140 		if reflect.TypeOf(v) == nil {
    141 			s = "%v"
    142 		}
    143 
    144 		return MetaData{}, fmt.Errorf("toml: cannot decode to non-pointer "+s, reflect.TypeOf(v))
    145 	}
    146 	if rv.IsNil() {
    147 		return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
    148 	}
    149 
    150 	// Check if this is a supported type: struct, map, any, or something that
    151 	// implements UnmarshalTOML or UnmarshalText.
    152 	rv = indirect(rv)
    153 	rt := rv.Type()
    154 	if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
    155 		!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
    156 		!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
    157 		return MetaData{}, fmt.Errorf("toml: cannot decode to type %s", rt)
    158 	}
    159 
    160 	// TODO: parser should read from io.Reader? Or at the very least, make it
    161 	// read from []byte rather than string
    162 	data, err := io.ReadAll(dec.r)
    163 	if err != nil {
    164 		return MetaData{}, err
    165 	}
    166 
    167 	p, err := parse(string(data))
    168 	if err != nil {
    169 		return MetaData{}, err
    170 	}
    171 
    172 	md := MetaData{
    173 		mapping: p.mapping,
    174 		keyInfo: p.keyInfo,
    175 		keys:    p.ordered,
    176 		decoded: make(map[string]struct{}, len(p.ordered)),
    177 		context: nil,
    178 		data:    data,
    179 	}
    180 	return md, md.unify(p.mapping, rv)
    181 }
    182 
    183 // PrimitiveDecode is just like the other Decode* functions, except it decodes a
    184 // TOML value that has already been parsed. Valid primitive values can *only* be
    185 // obtained from values filled by the decoder functions, including this method.
    186 // (i.e., v may contain more [Primitive] values.)
    187 //
    188 // Meta data for primitive values is included in the meta data returned by the
    189 // Decode* functions with one exception: keys returned by the Undecoded method
    190 // will only reflect keys that were decoded. Namely, any keys hidden behind a
    191 // Primitive will be considered undecoded. Executing this method will update the
    192 // undecoded keys in the meta data. (See the example.)
    193 func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error {
    194 	md.context = primValue.context
    195 	defer func() { md.context = nil }()
    196 	return md.unify(primValue.undecoded, rvalue(v))
    197 }
    198 
    199 // unify performs a sort of type unification based on the structure of `rv`,
    200 // which is the client representation.
    201 //
    202 // Any type mismatch produces an error. Finding a type that we don't know
    203 // how to handle produces an unsupported type error.
    204 func (md *MetaData) unify(data any, rv reflect.Value) error {
    205 	// Special case. Look for a `Primitive` value.
    206 	// TODO: #76 would make this superfluous after implemented.
    207 	if rv.Type() == primitiveType {
    208 		// Save the undecoded data and the key context into the primitive
    209 		// value.
    210 		context := make(Key, len(md.context))
    211 		copy(context, md.context)
    212 		rv.Set(reflect.ValueOf(Primitive{
    213 			undecoded: data,
    214 			context:   context,
    215 		}))
    216 		return nil
    217 	}
    218 
    219 	rvi := rv.Interface()
    220 	if v, ok := rvi.(Unmarshaler); ok {
    221 		err := v.UnmarshalTOML(data)
    222 		if err != nil {
    223 			return md.parseErr(err)
    224 		}
    225 		return nil
    226 	}
    227 	if v, ok := rvi.(encoding.TextUnmarshaler); ok {
    228 		return md.unifyText(data, v)
    229 	}
    230 
    231 	// TODO:
    232 	// The behavior here is incorrect whenever a Go type satisfies the
    233 	// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or
    234 	// array. In particular, the unmarshaler should only be applied to primitive
    235 	// TOML values. But at this point, it will be applied to all kinds of values
    236 	// and produce an incorrect error whenever those values are hashes or arrays
    237 	// (including arrays of tables).
    238 
    239 	k := rv.Kind()
    240 
    241 	if k >= reflect.Int && k <= reflect.Uint64 {
    242 		return md.unifyInt(data, rv)
    243 	}
    244 	switch k {
    245 	case reflect.Struct:
    246 		return md.unifyStruct(data, rv)
    247 	case reflect.Map:
    248 		return md.unifyMap(data, rv)
    249 	case reflect.Array:
    250 		return md.unifyArray(data, rv)
    251 	case reflect.Slice:
    252 		return md.unifySlice(data, rv)
    253 	case reflect.String:
    254 		return md.unifyString(data, rv)
    255 	case reflect.Bool:
    256 		return md.unifyBool(data, rv)
    257 	case reflect.Interface:
    258 		if rv.NumMethod() > 0 { /// Only empty interfaces are supported.
    259 			return md.e("unsupported type %s", rv.Type())
    260 		}
    261 		return md.unifyAnything(data, rv)
    262 	case reflect.Float32, reflect.Float64:
    263 		return md.unifyFloat64(data, rv)
    264 	}
    265 	return md.e("unsupported type %s", rv.Kind())
    266 }
    267 
    268 func (md *MetaData) unifyStruct(mapping any, rv reflect.Value) error {
    269 	tmap, ok := mapping.(map[string]any)
    270 	if !ok {
    271 		if mapping == nil {
    272 			return nil
    273 		}
    274 		return md.e("type mismatch for %s: expected table but found %s", rv.Type().String(), fmtType(mapping))
    275 	}
    276 
    277 	for key, datum := range tmap {
    278 		var f *field
    279 		fields := cachedTypeFields(rv.Type())
    280 		for i := range fields {
    281 			ff := &fields[i]
    282 			if ff.name == key {
    283 				f = ff
    284 				break
    285 			}
    286 			if f == nil && strings.EqualFold(ff.name, key) {
    287 				f = ff
    288 			}
    289 		}
    290 		if f != nil {
    291 			subv := rv
    292 			for _, i := range f.index {
    293 				subv = indirect(subv.Field(i))
    294 			}
    295 
    296 			if isUnifiable(subv) {
    297 				md.decoded[md.context.add(key).String()] = struct{}{}
    298 				md.context = append(md.context, key)
    299 
    300 				err := md.unify(datum, subv)
    301 				if err != nil {
    302 					return err
    303 				}
    304 				md.context = md.context[0 : len(md.context)-1]
    305 			} else if f.name != "" {
    306 				return md.e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
    307 			}
    308 		}
    309 	}
    310 	return nil
    311 }
    312 
    313 func (md *MetaData) unifyMap(mapping any, rv reflect.Value) error {
    314 	keyType := rv.Type().Key().Kind()
    315 	if keyType != reflect.String && keyType != reflect.Interface {
    316 		return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
    317 			keyType, rv.Type())
    318 	}
    319 
    320 	tmap, ok := mapping.(map[string]any)
    321 	if !ok {
    322 		if tmap == nil {
    323 			return nil
    324 		}
    325 		return md.badtype("map", mapping)
    326 	}
    327 	if rv.IsNil() {
    328 		rv.Set(reflect.MakeMap(rv.Type()))
    329 	}
    330 	for k, v := range tmap {
    331 		md.decoded[md.context.add(k).String()] = struct{}{}
    332 		md.context = append(md.context, k)
    333 
    334 		rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
    335 
    336 		err := md.unify(v, indirect(rvval))
    337 		if err != nil {
    338 			return err
    339 		}
    340 		md.context = md.context[0 : len(md.context)-1]
    341 
    342 		rvkey := indirect(reflect.New(rv.Type().Key()))
    343 
    344 		switch keyType {
    345 		case reflect.Interface:
    346 			rvkey.Set(reflect.ValueOf(k))
    347 		case reflect.String:
    348 			rvkey.SetString(k)
    349 		}
    350 
    351 		rv.SetMapIndex(rvkey, rvval)
    352 	}
    353 	return nil
    354 }
    355 
    356 func (md *MetaData) unifyArray(data any, rv reflect.Value) error {
    357 	datav := reflect.ValueOf(data)
    358 	if datav.Kind() != reflect.Slice {
    359 		if !datav.IsValid() {
    360 			return nil
    361 		}
    362 		return md.badtype("slice", data)
    363 	}
    364 	if l := datav.Len(); l != rv.Len() {
    365 		return md.e("expected array length %d; got TOML array of length %d", rv.Len(), l)
    366 	}
    367 	return md.unifySliceArray(datav, rv)
    368 }
    369 
    370 func (md *MetaData) unifySlice(data any, rv reflect.Value) error {
    371 	datav := reflect.ValueOf(data)
    372 	if datav.Kind() != reflect.Slice {
    373 		if !datav.IsValid() {
    374 			return nil
    375 		}
    376 		return md.badtype("slice", data)
    377 	}
    378 	n := datav.Len()
    379 	if rv.IsNil() || rv.Cap() < n {
    380 		rv.Set(reflect.MakeSlice(rv.Type(), n, n))
    381 	}
    382 	rv.SetLen(n)
    383 	return md.unifySliceArray(datav, rv)
    384 }
    385 
    386 func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
    387 	l := data.Len()
    388 	for i := 0; i < l; i++ {
    389 		err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i)))
    390 		if err != nil {
    391 			return err
    392 		}
    393 	}
    394 	return nil
    395 }
    396 
    397 func (md *MetaData) unifyString(data any, rv reflect.Value) error {
    398 	_, ok := rv.Interface().(json.Number)
    399 	if ok {
    400 		if i, ok := data.(int64); ok {
    401 			rv.SetString(strconv.FormatInt(i, 10))
    402 		} else if f, ok := data.(float64); ok {
    403 			rv.SetString(strconv.FormatFloat(f, 'f', -1, 64))
    404 		} else {
    405 			return md.badtype("string", data)
    406 		}
    407 		return nil
    408 	}
    409 
    410 	if s, ok := data.(string); ok {
    411 		rv.SetString(s)
    412 		return nil
    413 	}
    414 	return md.badtype("string", data)
    415 }
    416 
    417 func (md *MetaData) unifyFloat64(data any, rv reflect.Value) error {
    418 	rvk := rv.Kind()
    419 
    420 	if num, ok := data.(float64); ok {
    421 		switch rvk {
    422 		case reflect.Float32:
    423 			if num < -math.MaxFloat32 || num > math.MaxFloat32 {
    424 				return md.parseErr(errParseRange{i: num, size: rvk.String()})
    425 			}
    426 			fallthrough
    427 		case reflect.Float64:
    428 			rv.SetFloat(num)
    429 		default:
    430 			panic("bug")
    431 		}
    432 		return nil
    433 	}
    434 
    435 	if num, ok := data.(int64); ok {
    436 		if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
    437 			(rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
    438 			return md.parseErr(errUnsafeFloat{i: num, size: rvk.String()})
    439 		}
    440 		rv.SetFloat(float64(num))
    441 		return nil
    442 	}
    443 
    444 	return md.badtype("float", data)
    445 }
    446 
    447 func (md *MetaData) unifyInt(data any, rv reflect.Value) error {
    448 	_, ok := rv.Interface().(time.Duration)
    449 	if ok {
    450 		// Parse as string duration, and fall back to regular integer parsing
    451 		// (as nanosecond) if this is not a string.
    452 		if s, ok := data.(string); ok {
    453 			dur, err := time.ParseDuration(s)
    454 			if err != nil {
    455 				return md.parseErr(errParseDuration{s})
    456 			}
    457 			rv.SetInt(int64(dur))
    458 			return nil
    459 		}
    460 	}
    461 
    462 	num, ok := data.(int64)
    463 	if !ok {
    464 		return md.badtype("integer", data)
    465 	}
    466 
    467 	rvk := rv.Kind()
    468 	switch {
    469 	case rvk >= reflect.Int && rvk <= reflect.Int64:
    470 		if (rvk == reflect.Int8 && (num < math.MinInt8 || num > math.MaxInt8)) ||
    471 			(rvk == reflect.Int16 && (num < math.MinInt16 || num > math.MaxInt16)) ||
    472 			(rvk == reflect.Int32 && (num < math.MinInt32 || num > math.MaxInt32)) {
    473 			return md.parseErr(errParseRange{i: num, size: rvk.String()})
    474 		}
    475 		rv.SetInt(num)
    476 	case rvk >= reflect.Uint && rvk <= reflect.Uint64:
    477 		unum := uint64(num)
    478 		if rvk == reflect.Uint8 && (num < 0 || unum > math.MaxUint8) ||
    479 			rvk == reflect.Uint16 && (num < 0 || unum > math.MaxUint16) ||
    480 			rvk == reflect.Uint32 && (num < 0 || unum > math.MaxUint32) {
    481 			return md.parseErr(errParseRange{i: num, size: rvk.String()})
    482 		}
    483 		rv.SetUint(unum)
    484 	default:
    485 		panic("unreachable")
    486 	}
    487 	return nil
    488 }
    489 
    490 func (md *MetaData) unifyBool(data any, rv reflect.Value) error {
    491 	if b, ok := data.(bool); ok {
    492 		rv.SetBool(b)
    493 		return nil
    494 	}
    495 	return md.badtype("boolean", data)
    496 }
    497 
    498 func (md *MetaData) unifyAnything(data any, rv reflect.Value) error {
    499 	rv.Set(reflect.ValueOf(data))
    500 	return nil
    501 }
    502 
    503 func (md *MetaData) unifyText(data any, v encoding.TextUnmarshaler) error {
    504 	var s string
    505 	switch sdata := data.(type) {
    506 	case Marshaler:
    507 		text, err := sdata.MarshalTOML()
    508 		if err != nil {
    509 			return err
    510 		}
    511 		s = string(text)
    512 	case encoding.TextMarshaler:
    513 		text, err := sdata.MarshalText()
    514 		if err != nil {
    515 			return err
    516 		}
    517 		s = string(text)
    518 	case fmt.Stringer:
    519 		s = sdata.String()
    520 	case string:
    521 		s = sdata
    522 	case bool:
    523 		s = fmt.Sprintf("%v", sdata)
    524 	case int64:
    525 		s = fmt.Sprintf("%d", sdata)
    526 	case float64:
    527 		s = fmt.Sprintf("%f", sdata)
    528 	default:
    529 		return md.badtype("primitive (string-like)", data)
    530 	}
    531 	if err := v.UnmarshalText([]byte(s)); err != nil {
    532 		return md.parseErr(err)
    533 	}
    534 	return nil
    535 }
    536 
    537 func (md *MetaData) badtype(dst string, data any) error {
    538 	return md.e("incompatible types: TOML value has type %s; destination has type %s", fmtType(data), dst)
    539 }
    540 
    541 func (md *MetaData) parseErr(err error) error {
    542 	k := md.context.String()
    543 	return ParseError{
    544 		LastKey:  k,
    545 		Position: md.keyInfo[k].pos,
    546 		Line:     md.keyInfo[k].pos.Line,
    547 		err:      err,
    548 		input:    string(md.data),
    549 	}
    550 }
    551 
    552 func (md *MetaData) e(format string, args ...any) error {
    553 	f := "toml: "
    554 	if len(md.context) > 0 {
    555 		f = fmt.Sprintf("toml: (last key %q): ", md.context)
    556 		p := md.keyInfo[md.context.String()].pos
    557 		if p.Line > 0 {
    558 			f = fmt.Sprintf("toml: line %d (last key %q): ", p.Line, md.context)
    559 		}
    560 	}
    561 	return fmt.Errorf(f+format, args...)
    562 }
    563 
    564 // rvalue returns a reflect.Value of `v`. All pointers are resolved.
    565 func rvalue(v any) reflect.Value {
    566 	return indirect(reflect.ValueOf(v))
    567 }
    568 
    569 // indirect returns the value pointed to by a pointer.
    570 //
    571 // Pointers are followed until the value is not a pointer. New values are
    572 // allocated for each nil pointer.
    573 //
    574 // An exception to this rule is if the value satisfies an interface of interest
    575 // to us (like encoding.TextUnmarshaler).
    576 func indirect(v reflect.Value) reflect.Value {
    577 	if v.Kind() != reflect.Ptr {
    578 		if v.CanSet() {
    579 			pv := v.Addr()
    580 			pvi := pv.Interface()
    581 			if _, ok := pvi.(encoding.TextUnmarshaler); ok {
    582 				return pv
    583 			}
    584 			if _, ok := pvi.(Unmarshaler); ok {
    585 				return pv
    586 			}
    587 		}
    588 		return v
    589 	}
    590 	if v.IsNil() {
    591 		v.Set(reflect.New(v.Type().Elem()))
    592 	}
    593 	return indirect(reflect.Indirect(v))
    594 }
    595 
    596 func isUnifiable(rv reflect.Value) bool {
    597 	if rv.CanSet() {
    598 		return true
    599 	}
    600 	rvi := rv.Interface()
    601 	if _, ok := rvi.(encoding.TextUnmarshaler); ok {
    602 		return true
    603 	}
    604 	if _, ok := rvi.(Unmarshaler); ok {
    605 		return true
    606 	}
    607 	return false
    608 }
    609 
    610 // fmt %T with "interface {}" replaced with "any", which is far more readable.
    611 func fmtType(t any) string {
    612 	return strings.ReplaceAll(fmt.Sprintf("%T", t), "interface {}", "any")
    613 }