You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

600 lines
20 KiB

  1. package validator
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "strings"
  8. "sync"
  9. "time"
  10. ut "github.com/go-playground/universal-translator"
  11. )
  // Tag-syntax constants shared by the tag parser and the Validate API.
const (
	// defaultTagName is the struct tag key read when SetTagName is not used.
	defaultTagName = "validate"

	// Hex spellings of ',' (0x2C) and '|' (0x7C).
	utf8HexComma = "0x2C"
	utf8Pipe     = "0x7C"

	// Separators inside a tag value.
	tagSeparator    = ","
	orSeparator     = "|"
	tagKeySeparator = "="

	// Tag names with special meaning to the validator.
	structOnlyTag     = "structonly"
	noStructLevelTag  = "nostructlevel"
	omitempty         = "omitempty"
	isdefault         = "isdefault"
	skipValidationTag = "-"
	diveTag           = "dive"
	keysTag           = "keys"
	endKeysTag        = "endkeys"
	requiredTag       = "required"

	// Pieces used to build field namespaces such as "User.Addresses[0].City".
	namespaceSeparator = "."
	leftBracket        = "["
	rightBracket       = "]"

	// restrictedTagChars lists characters a user-registered tag or alias may
	// not contain; the two messages below are the panic texts used when one does.
	restrictedTagChars = ".[],|=+()`~!@#$%^&*\\\"/?<>{}"
	restrictedAliasErr = "Alias '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
	restrictedTagErr   = "Tag '%s' either contains restricted characters or is the same as a restricted tag needed for normal operation"
)
  35. var (
  36. timeType = reflect.TypeOf(time.Time{})
  37. defaultCField = &cField{namesEqual: true}
  38. )
  // Function types accepted by the Validate configuration API.
type (
	// FilterFunc is the callback used by StructFiltered(...). It receives a
	// field namespace (presumably of the form "Type.Field" — confirm against
	// validateStruct); returning true filters, i.e. skips, that field from
	// validation.
	FilterFunc func(ns []byte) bool

	// CustomTypeFunc overrides or adds a custom field type handler. Given the
	// field's value it returns the value that should be validated instead,
	// e.g. a sql driver.Valuer, see
	// https://golang.org/src/database/sql/driver/types.go?s=1210:1293#L29
	CustomTypeFunc func(field reflect.Value) interface{}

	// TagNameFunc supplies an alternate (e.g. JSON) name for a struct field.
	TagNameFunc func(field reflect.StructField) string
)
  50. // Validate contains the validator settings and cache
  51. type Validate struct {
  52. tagName string
  53. pool *sync.Pool
  54. hasCustomFuncs bool
  55. hasTagNameFunc bool
  56. tagNameFunc TagNameFunc
  57. structLevelFuncs map[reflect.Type]StructLevelFuncCtx
  58. customFuncs map[reflect.Type]CustomTypeFunc
  59. aliases map[string]string
  60. validations map[string]FuncCtx
  61. transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc
  62. tagCache *tagCache
  63. structCache *structCache
  64. }
  65. // New returns a new instance of 'validate' with sane defaults.
  66. func New() *Validate {
  67. tc := new(tagCache)
  68. tc.m.Store(make(map[string]*cTag))
  69. sc := new(structCache)
  70. sc.m.Store(make(map[reflect.Type]*cStruct))
  71. v := &Validate{
  72. tagName: defaultTagName,
  73. aliases: make(map[string]string, len(bakedInAliases)),
  74. validations: make(map[string]FuncCtx, len(bakedInValidators)),
  75. tagCache: tc,
  76. structCache: sc,
  77. }
  78. // must copy alias validators for separate validations to be used in each validator instance
  79. for k, val := range bakedInAliases {
  80. v.RegisterAlias(k, val)
  81. }
  82. // must copy validators for separate validations to be used in each instance
  83. for k, val := range bakedInValidators {
  84. // no need to error check here, baked in will always be valid
  85. _ = v.registerValidation(k, wrapFunc(val), true)
  86. }
  87. v.pool = &sync.Pool{
  88. New: func() interface{} {
  89. return &validate{
  90. v: v,
  91. ns: make([]byte, 0, 64),
  92. actualNs: make([]byte, 0, 64),
  93. misc: make([]byte, 32),
  94. }
  95. },
  96. }
  97. return v
  98. }
  99. // SetTagName allows for changing of the default tag name of 'validate'
  100. func (v *Validate) SetTagName(name string) {
  101. v.tagName = name
  102. }
  103. // RegisterTagNameFunc registers a function to get alternate names for StructFields.
  104. //
  105. // eg. to use the names which have been specified for JSON representations of structs, rather than normal Go field names:
  106. //
  107. // validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
  108. // name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
  109. // if name == "-" {
  110. // return ""
  111. // }
  112. // return name
  113. // })
  114. func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
  115. v.tagNameFunc = fn
  116. v.hasTagNameFunc = true
  117. }
  118. // RegisterValidation adds a validation with the given tag
  119. //
  120. // NOTES:
  121. // - if the key already exists, the previous validation function will be replaced.
  122. // - this method is not thread-safe it is intended that these all be registered prior to any validation
  123. func (v *Validate) RegisterValidation(tag string, fn Func) error {
  124. return v.RegisterValidationCtx(tag, wrapFunc(fn))
  125. }
  126. // RegisterValidationCtx does the same as RegisterValidation on accepts a FuncCtx validation
  127. // allowing context.Context validation support.
  128. func (v *Validate) RegisterValidationCtx(tag string, fn FuncCtx) error {
  129. return v.registerValidation(tag, fn, false)
  130. }
  131. func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool) error {
  132. if len(tag) == 0 {
  133. return errors.New("Function Key cannot be empty")
  134. }
  135. if fn == nil {
  136. return errors.New("Function cannot be empty")
  137. }
  138. _, ok := restrictedTags[tag]
  139. if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) {
  140. panic(fmt.Sprintf(restrictedTagErr, tag))
  141. }
  142. v.validations[tag] = fn
  143. return nil
  144. }
  145. // RegisterAlias registers a mapping of a single validation tag that
  146. // defines a common or complex set of validation(s) to simplify adding validation
  147. // to structs.
  148. //
  149. // NOTE: this function is not thread-safe it is intended that these all be registered prior to any validation
  150. func (v *Validate) RegisterAlias(alias, tags string) {
  151. _, ok := restrictedTags[alias]
  152. if ok || strings.ContainsAny(alias, restrictedTagChars) {
  153. panic(fmt.Sprintf(restrictedAliasErr, alias))
  154. }
  155. v.aliases[alias] = tags
  156. }
  157. // RegisterStructValidation registers a StructLevelFunc against a number of types.
  158. //
  159. // NOTE:
  160. // - this method is not thread-safe it is intended that these all be registered prior to any validation
  161. func (v *Validate) RegisterStructValidation(fn StructLevelFunc, types ...interface{}) {
  162. v.RegisterStructValidationCtx(wrapStructLevelFunc(fn), types...)
  163. }
  164. // RegisterStructValidationCtx registers a StructLevelFuncCtx against a number of types and allows passing
  165. // of contextual validation information via context.Context.
  166. //
  167. // NOTE:
  168. // - this method is not thread-safe it is intended that these all be registered prior to any validation
  169. func (v *Validate) RegisterStructValidationCtx(fn StructLevelFuncCtx, types ...interface{}) {
  170. if v.structLevelFuncs == nil {
  171. v.structLevelFuncs = make(map[reflect.Type]StructLevelFuncCtx)
  172. }
  173. for _, t := range types {
  174. tv := reflect.ValueOf(t)
  175. if tv.Kind() == reflect.Ptr {
  176. t = reflect.Indirect(tv).Interface()
  177. }
  178. v.structLevelFuncs[reflect.TypeOf(t)] = fn
  179. }
  180. }
  181. // RegisterCustomTypeFunc registers a CustomTypeFunc against a number of types
  182. //
  183. // NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
  184. func (v *Validate) RegisterCustomTypeFunc(fn CustomTypeFunc, types ...interface{}) {
  185. if v.customFuncs == nil {
  186. v.customFuncs = make(map[reflect.Type]CustomTypeFunc)
  187. }
  188. for _, t := range types {
  189. v.customFuncs[reflect.TypeOf(t)] = fn
  190. }
  191. v.hasCustomFuncs = true
  192. }
  193. // RegisterTranslation registers translations against the provided tag.
  194. func (v *Validate) RegisterTranslation(tag string, trans ut.Translator, registerFn RegisterTranslationsFunc, translationFn TranslationFunc) (err error) {
  195. if v.transTagFunc == nil {
  196. v.transTagFunc = make(map[ut.Translator]map[string]TranslationFunc)
  197. }
  198. if err = registerFn(trans); err != nil {
  199. return
  200. }
  201. m, ok := v.transTagFunc[trans]
  202. if !ok {
  203. m = make(map[string]TranslationFunc)
  204. v.transTagFunc[trans] = m
  205. }
  206. m[tag] = translationFn
  207. return
  208. }
  209. // Struct validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified.
  210. //
  211. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  212. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  213. func (v *Validate) Struct(s interface{}) error {
  214. return v.StructCtx(context.Background(), s)
  215. }
  216. // StructCtx validates a structs exposed fields, and automatically validates nested structs, unless otherwise specified
  217. // and also allows passing of context.Context for contextual validation information.
  218. //
  219. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  220. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  221. func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) {
  222. val := reflect.ValueOf(s)
  223. top := val
  224. if val.Kind() == reflect.Ptr && !val.IsNil() {
  225. val = val.Elem()
  226. }
  227. if val.Kind() != reflect.Struct || val.Type() == timeType {
  228. return &InvalidValidationError{Type: reflect.TypeOf(s)}
  229. }
  230. // good to validate
  231. vd := v.pool.Get().(*validate)
  232. vd.top = top
  233. vd.isPartial = false
  234. // vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
  235. vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
  236. if len(vd.errs) > 0 {
  237. err = vd.errs
  238. vd.errs = nil
  239. }
  240. v.pool.Put(vd)
  241. return
  242. }
  243. // StructFiltered validates a structs exposed fields, that pass the FilterFunc check and automatically validates
  244. // nested structs, unless otherwise specified.
  245. //
  246. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  247. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  248. func (v *Validate) StructFiltered(s interface{}, fn FilterFunc) error {
  249. return v.StructFilteredCtx(context.Background(), s, fn)
  250. }
  251. // StructFilteredCtx validates a structs exposed fields, that pass the FilterFunc check and automatically validates
  252. // nested structs, unless otherwise specified and also allows passing of contextual validation information via
  253. // context.Context
  254. //
  255. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  256. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  257. func (v *Validate) StructFilteredCtx(ctx context.Context, s interface{}, fn FilterFunc) (err error) {
  258. val := reflect.ValueOf(s)
  259. top := val
  260. if val.Kind() == reflect.Ptr && !val.IsNil() {
  261. val = val.Elem()
  262. }
  263. if val.Kind() != reflect.Struct || val.Type() == timeType {
  264. return &InvalidValidationError{Type: reflect.TypeOf(s)}
  265. }
  266. // good to validate
  267. vd := v.pool.Get().(*validate)
  268. vd.top = top
  269. vd.isPartial = true
  270. vd.ffn = fn
  271. // vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
  272. vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
  273. if len(vd.errs) > 0 {
  274. err = vd.errs
  275. vd.errs = nil
  276. }
  277. v.pool.Put(vd)
  278. return
  279. }
  280. // StructPartial validates the fields passed in only, ignoring all others.
  281. // Fields may be provided in a namespaced fashion relative to the struct provided
  282. // eg. NestedStruct.Field or NestedArrayField[0].Struct.Name
  283. //
  284. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  285. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  286. func (v *Validate) StructPartial(s interface{}, fields ...string) error {
  287. return v.StructPartialCtx(context.Background(), s, fields...)
  288. }
  289. // StructPartialCtx validates the fields passed in only, ignoring all others and allows passing of contextual
  290. // validation validation information via context.Context
  291. // Fields may be provided in a namespaced fashion relative to the struct provided
  292. // eg. NestedStruct.Field or NestedArrayField[0].Struct.Name
  293. //
  294. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  295. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  296. func (v *Validate) StructPartialCtx(ctx context.Context, s interface{}, fields ...string) (err error) {
  297. val := reflect.ValueOf(s)
  298. top := val
  299. if val.Kind() == reflect.Ptr && !val.IsNil() {
  300. val = val.Elem()
  301. }
  302. if val.Kind() != reflect.Struct || val.Type() == timeType {
  303. return &InvalidValidationError{Type: reflect.TypeOf(s)}
  304. }
  305. // good to validate
  306. vd := v.pool.Get().(*validate)
  307. vd.top = top
  308. vd.isPartial = true
  309. vd.ffn = nil
  310. vd.hasExcludes = false
  311. vd.includeExclude = make(map[string]struct{})
  312. typ := val.Type()
  313. name := typ.Name()
  314. for _, k := range fields {
  315. flds := strings.Split(k, namespaceSeparator)
  316. if len(flds) > 0 {
  317. vd.misc = append(vd.misc[0:0], name...)
  318. vd.misc = append(vd.misc, '.')
  319. for _, s := range flds {
  320. idx := strings.Index(s, leftBracket)
  321. if idx != -1 {
  322. for idx != -1 {
  323. vd.misc = append(vd.misc, s[:idx]...)
  324. vd.includeExclude[string(vd.misc)] = struct{}{}
  325. idx2 := strings.Index(s, rightBracket)
  326. idx2++
  327. vd.misc = append(vd.misc, s[idx:idx2]...)
  328. vd.includeExclude[string(vd.misc)] = struct{}{}
  329. s = s[idx2:]
  330. idx = strings.Index(s, leftBracket)
  331. }
  332. } else {
  333. vd.misc = append(vd.misc, s...)
  334. vd.includeExclude[string(vd.misc)] = struct{}{}
  335. }
  336. vd.misc = append(vd.misc, '.')
  337. }
  338. }
  339. }
  340. vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
  341. if len(vd.errs) > 0 {
  342. err = vd.errs
  343. vd.errs = nil
  344. }
  345. v.pool.Put(vd)
  346. return
  347. }
  348. // StructExcept validates all fields except the ones passed in.
  349. // Fields may be provided in a namespaced fashion relative to the struct provided
  350. // i.e. NestedStruct.Field or NestedArrayField[0].Struct.Name
  351. //
  352. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  353. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  354. func (v *Validate) StructExcept(s interface{}, fields ...string) error {
  355. return v.StructExceptCtx(context.Background(), s, fields...)
  356. }
  357. // StructExceptCtx validates all fields except the ones passed in and allows passing of contextual
  358. // validation validation information via context.Context
  359. // Fields may be provided in a namespaced fashion relative to the struct provided
  360. // i.e. NestedStruct.Field or NestedArrayField[0].Struct.Name
  361. //
  362. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  363. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  364. func (v *Validate) StructExceptCtx(ctx context.Context, s interface{}, fields ...string) (err error) {
  365. val := reflect.ValueOf(s)
  366. top := val
  367. if val.Kind() == reflect.Ptr && !val.IsNil() {
  368. val = val.Elem()
  369. }
  370. if val.Kind() != reflect.Struct || val.Type() == timeType {
  371. return &InvalidValidationError{Type: reflect.TypeOf(s)}
  372. }
  373. // good to validate
  374. vd := v.pool.Get().(*validate)
  375. vd.top = top
  376. vd.isPartial = true
  377. vd.ffn = nil
  378. vd.hasExcludes = true
  379. vd.includeExclude = make(map[string]struct{})
  380. typ := val.Type()
  381. name := typ.Name()
  382. for _, key := range fields {
  383. vd.misc = vd.misc[0:0]
  384. if len(name) > 0 {
  385. vd.misc = append(vd.misc, name...)
  386. vd.misc = append(vd.misc, '.')
  387. }
  388. vd.misc = append(vd.misc, key...)
  389. vd.includeExclude[string(vd.misc)] = struct{}{}
  390. }
  391. vd.validateStruct(ctx, top, val, typ, vd.ns[0:0], vd.actualNs[0:0], nil)
  392. if len(vd.errs) > 0 {
  393. err = vd.errs
  394. vd.errs = nil
  395. }
  396. v.pool.Put(vd)
  397. return
  398. }
  399. // Var validates a single variable using tag style validation.
  400. // eg.
  401. // var i int
  402. // validate.Var(i, "gt=1,lt=10")
  403. //
  404. // WARNING: a struct can be passed for validation eg. time.Time is a struct or
  405. // if you have a custom type and have registered a custom type handler, so must
  406. // allow it; however unforeseen validations will occur if trying to validate a
  407. // struct that is meant to be passed to 'validate.Struct'
  408. //
  409. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  410. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  411. // validate Array, Slice and maps fields which may contain more than one error
  412. func (v *Validate) Var(field interface{}, tag string) error {
  413. return v.VarCtx(context.Background(), field, tag)
  414. }
  415. // VarCtx validates a single variable using tag style validation and allows passing of contextual
  416. // validation validation information via context.Context.
  417. // eg.
  418. // var i int
  419. // validate.Var(i, "gt=1,lt=10")
  420. //
  421. // WARNING: a struct can be passed for validation eg. time.Time is a struct or
  422. // if you have a custom type and have registered a custom type handler, so must
  423. // allow it; however unforeseen validations will occur if trying to validate a
  424. // struct that is meant to be passed to 'validate.Struct'
  425. //
  426. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  427. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  428. // validate Array, Slice and maps fields which may contain more than one error
  429. func (v *Validate) VarCtx(ctx context.Context, field interface{}, tag string) (err error) {
  430. if len(tag) == 0 || tag == skipValidationTag {
  431. return nil
  432. }
  433. ctag := v.fetchCacheTag(tag)
  434. val := reflect.ValueOf(field)
  435. vd := v.pool.Get().(*validate)
  436. vd.top = val
  437. vd.isPartial = false
  438. vd.traverseField(ctx, val, val, vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
  439. if len(vd.errs) > 0 {
  440. err = vd.errs
  441. vd.errs = nil
  442. }
  443. v.pool.Put(vd)
  444. return
  445. }
  446. // VarWithValue validates a single variable, against another variable/field's value using tag style validation
  447. // eg.
  448. // s1 := "abcd"
  449. // s2 := "abcd"
  450. // validate.VarWithValue(s1, s2, "eqcsfield") // returns true
  451. //
  452. // WARNING: a struct can be passed for validation eg. time.Time is a struct or
  453. // if you have a custom type and have registered a custom type handler, so must
  454. // allow it; however unforeseen validations will occur if trying to validate a
  455. // struct that is meant to be passed to 'validate.Struct'
  456. //
  457. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  458. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  459. // validate Array, Slice and maps fields which may contain more than one error
  460. func (v *Validate) VarWithValue(field interface{}, other interface{}, tag string) error {
  461. return v.VarWithValueCtx(context.Background(), field, other, tag)
  462. }
  463. // VarWithValueCtx validates a single variable, against another variable/field's value using tag style validation and
  464. // allows passing of contextual validation validation information via context.Context.
  465. // eg.
  466. // s1 := "abcd"
  467. // s2 := "abcd"
  468. // validate.VarWithValue(s1, s2, "eqcsfield") // returns true
  469. //
  470. // WARNING: a struct can be passed for validation eg. time.Time is a struct or
  471. // if you have a custom type and have registered a custom type handler, so must
  472. // allow it; however unforeseen validations will occur if trying to validate a
  473. // struct that is meant to be passed to 'validate.Struct'
  474. //
  475. // It returns InvalidValidationError for bad values passed in and nil or ValidationErrors as error otherwise.
  476. // You will need to assert the error if it's not nil eg. err.(validator.ValidationErrors) to access the array of errors.
  477. // validate Array, Slice and maps fields which may contain more than one error
  478. func (v *Validate) VarWithValueCtx(ctx context.Context, field interface{}, other interface{}, tag string) (err error) {
  479. if len(tag) == 0 || tag == skipValidationTag {
  480. return nil
  481. }
  482. ctag := v.fetchCacheTag(tag)
  483. otherVal := reflect.ValueOf(other)
  484. vd := v.pool.Get().(*validate)
  485. vd.top = otherVal
  486. vd.isPartial = false
  487. vd.traverseField(ctx, otherVal, reflect.ValueOf(field), vd.ns[0:0], vd.actualNs[0:0], defaultCField, ctag)
  488. if len(vd.errs) > 0 {
  489. err = vd.errs
  490. vd.errs = nil
  491. }
  492. v.pool.Put(vd)
  493. return
  494. }