You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

337 lines
7.2 KiB

  1. package validator
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. )
// tagType classifies a parsed cTag node so that control tags (omitempty,
// dive, keys, ...) can be told apart from ordinary validations.
// Values are assigned with iota, so reordering the constants renumbers them.
type tagType uint8

const (
	// typeDefault is the zero value: an ordinary validation tag, and also the
	// type of the empty cTag created (via new) for fields without a tag.
	typeDefault tagType = iota
	// typeOmitEmpty is assigned to the omitempty tag.
	typeOmitEmpty
	// typeIsDefault is assigned to the isdefault tag.
	typeIsDefault
	// typeNoStructLevel is assigned to the noStructLevelTag tag.
	typeNoStructLevel
	// typeStructOnly is assigned to the structOnlyTag tag.
	typeStructOnly
	// typeDive is assigned to the diveTag tag.
	typeDive
	// typeOr is assigned to every member of a multi-member OR group
	// (a tag segment split by orSeparator into more than one value).
	typeOr
	// typeKeys is assigned to the keysTag tag, which opens a map-key chain.
	typeKeys
	// typeEndKeys is assigned to the endKeysTag tag, which closes it.
	typeEndKeys
)
// Panic messages raised by parseFieldTagsRecursive for malformed tags.
const (
	// invalidValidation is formatted with the field name when a tag segment
	// has an empty validation name.
	invalidValidation = "Invalid validation tag on field '%s'"
	// undefinedValidation is formatted with the tag name and the field name
	// when no entry for the tag exists in the validations map.
	undefinedValidation = "Undefined validation function '%s' on field '%s'"
	// keysTagNotDefined is raised when an endkeys tag is not the last tag of
	// the chain being parsed, i.e. it has no opening keys tag.
	keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
)
  26. type structCache struct {
  27. lock sync.Mutex
  28. m atomic.Value // map[reflect.Type]*cStruct
  29. }
  30. func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
  31. c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
  32. return
  33. }
  34. func (sc *structCache) Set(key reflect.Type, value *cStruct) {
  35. m := sc.m.Load().(map[reflect.Type]*cStruct)
  36. nm := make(map[reflect.Type]*cStruct, len(m)+1)
  37. for k, v := range m {
  38. nm[k] = v
  39. }
  40. nm[key] = value
  41. sc.m.Store(nm)
  42. }
  43. type tagCache struct {
  44. lock sync.Mutex
  45. m atomic.Value // map[string]*cTag
  46. }
  47. func (tc *tagCache) Get(key string) (c *cTag, found bool) {
  48. c, found = tc.m.Load().(map[string]*cTag)[key]
  49. return
  50. }
  51. func (tc *tagCache) Set(key string, value *cTag) {
  52. m := tc.m.Load().(map[string]*cTag)
  53. nm := make(map[string]*cTag, len(m)+1)
  54. for k, v := range m {
  55. nm[k] = v
  56. }
  57. nm[key] = value
  58. tc.m.Store(nm)
  59. }
// cStruct is the cached, pre-parsed validation metadata for one struct type,
// as built by extractStructCache.
type cStruct struct {
	name   string             // struct name passed to extractStructCache (sName)
	fields []*cField          // one entry per field that was not skipped, in declaration order
	fn     StructLevelFuncCtx // value of v.structLevelFuncs for the type (zero value if absent)
}
  65. type cField struct {
  66. idx int
  67. name string
  68. altName string
  69. namesEqual bool
  70. cTags *cTag
  71. }
  72. type cTag struct {
  73. tag string
  74. aliasTag string
  75. actualAliasTag string
  76. param string
  77. keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
  78. next *cTag
  79. fn FuncCtx
  80. typeof tagType
  81. hasTag bool
  82. hasAlias bool
  83. hasParam bool // true if parameter used eg. eq= where the equal sign has been set
  84. isBlockEnd bool // indicates the current tag represents the last validation in the block
  85. }
  86. func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
  87. v.structCache.lock.Lock()
  88. defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
  89. typ := current.Type()
  90. // could have been multiple trying to access, but once first is done this ensures struct
  91. // isn't parsed again.
  92. cs, ok := v.structCache.Get(typ)
  93. if ok {
  94. return cs
  95. }
  96. cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
  97. numFields := current.NumField()
  98. var ctag *cTag
  99. var fld reflect.StructField
  100. var tag string
  101. var customName string
  102. for i := 0; i < numFields; i++ {
  103. fld = typ.Field(i)
  104. if !fld.Anonymous && len(fld.PkgPath) > 0 {
  105. continue
  106. }
  107. tag = fld.Tag.Get(v.tagName)
  108. if tag == skipValidationTag {
  109. continue
  110. }
  111. customName = fld.Name
  112. if v.hasTagNameFunc {
  113. name := v.tagNameFunc(fld)
  114. if len(name) > 0 {
  115. customName = name
  116. }
  117. }
  118. // NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
  119. // and so only struct level caching can be used instead of combined with Field tag caching
  120. if len(tag) > 0 {
  121. ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
  122. } else {
  123. // even if field doesn't have validations need cTag for traversing to potential inner/nested
  124. // elements of the field.
  125. ctag = new(cTag)
  126. }
  127. cs.fields = append(cs.fields, &cField{
  128. idx: i,
  129. name: fld.Name,
  130. altName: customName,
  131. cTags: ctag,
  132. namesEqual: fld.Name == customName,
  133. })
  134. }
  135. v.structCache.Set(typ, cs)
  136. return cs
  137. }
  138. func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
  139. var t string
  140. var ok bool
  141. noAlias := len(alias) == 0
  142. tags := strings.Split(tag, tagSeparator)
  143. for i := 0; i < len(tags); i++ {
  144. t = tags[i]
  145. if noAlias {
  146. alias = t
  147. }
  148. // check map for alias and process new tags, otherwise process as usual
  149. if tagsVal, found := v.aliases[t]; found {
  150. if i == 0 {
  151. firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
  152. } else {
  153. next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
  154. current.next, current = next, curr
  155. }
  156. continue
  157. }
  158. var prevTag tagType
  159. if i == 0 {
  160. current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
  161. firstCtag = current
  162. } else {
  163. prevTag = current.typeof
  164. current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
  165. current = current.next
  166. }
  167. switch t {
  168. case diveTag:
  169. current.typeof = typeDive
  170. continue
  171. case keysTag:
  172. current.typeof = typeKeys
  173. if i == 0 || prevTag != typeDive {
  174. panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
  175. }
  176. current.typeof = typeKeys
  177. // need to pass along only keys tag
  178. // need to increment i to skip over the keys tags
  179. b := make([]byte, 0, 64)
  180. i++
  181. for ; i < len(tags); i++ {
  182. b = append(b, tags[i]...)
  183. b = append(b, ',')
  184. if tags[i] == endKeysTag {
  185. break
  186. }
  187. }
  188. current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
  189. continue
  190. case endKeysTag:
  191. current.typeof = typeEndKeys
  192. // if there are more in tags then there was no keysTag defined
  193. // and an error should be thrown
  194. if i != len(tags)-1 {
  195. panic(keysTagNotDefined)
  196. }
  197. return
  198. case omitempty:
  199. current.typeof = typeOmitEmpty
  200. continue
  201. case structOnlyTag:
  202. current.typeof = typeStructOnly
  203. continue
  204. case noStructLevelTag:
  205. current.typeof = typeNoStructLevel
  206. continue
  207. default:
  208. if t == isdefault {
  209. current.typeof = typeIsDefault
  210. }
  211. // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
  212. orVals := strings.Split(t, orSeparator)
  213. for j := 0; j < len(orVals); j++ {
  214. vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
  215. if noAlias {
  216. alias = vals[0]
  217. current.aliasTag = alias
  218. } else {
  219. current.actualAliasTag = t
  220. }
  221. if j > 0 {
  222. current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
  223. current = current.next
  224. }
  225. current.hasParam = len(vals) > 1
  226. current.tag = vals[0]
  227. if len(current.tag) == 0 {
  228. panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
  229. }
  230. if current.fn, ok = v.validations[current.tag]; !ok {
  231. panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
  232. }
  233. if len(orVals) > 1 {
  234. current.typeof = typeOr
  235. }
  236. if len(vals) > 1 {
  237. current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
  238. }
  239. }
  240. current.isBlockEnd = true
  241. }
  242. }
  243. return
  244. }
  245. func (v *Validate) fetchCacheTag(tag string) *cTag {
  246. // find cached tag
  247. ctag, found := v.tagCache.Get(tag)
  248. if !found {
  249. v.tagCache.lock.Lock()
  250. defer v.tagCache.lock.Unlock()
  251. // could have been multiple trying to access, but once first is done this ensures tag
  252. // isn't parsed again.
  253. ctag, found = v.tagCache.Get(tag)
  254. if !found {
  255. ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
  256. v.tagCache.Set(tag, ctag)
  257. }
  258. }
  259. return ctag
  260. }