callback_query_preload.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strconv"
  7. "strings"
  8. )
  9. // preloadCallback used to preload associations
  10. func preloadCallback(scope *Scope) {
  11. if _, ok := scope.Get("gorm:auto_preload"); ok {
  12. autoPreload(scope)
  13. }
  14. if scope.Search.preload == nil || scope.HasError() {
  15. return
  16. }
  17. var (
  18. preloadedMap = map[string]bool{}
  19. fields = scope.Fields()
  20. )
  21. for _, preload := range scope.Search.preload {
  22. var (
  23. preloadFields = strings.Split(preload.schema, ".")
  24. currentScope = scope
  25. currentFields = fields
  26. )
  27. for idx, preloadField := range preloadFields {
  28. var currentPreloadConditions []interface{}
  29. if currentScope == nil {
  30. continue
  31. }
  32. // if not preloaded
  33. if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
  34. // assign search conditions to last preload
  35. if idx == len(preloadFields)-1 {
  36. currentPreloadConditions = preload.conditions
  37. }
  38. for _, field := range currentFields {
  39. if field.Name != preloadField || field.Relationship == nil {
  40. continue
  41. }
  42. switch field.Relationship.Kind {
  43. case "has_one":
  44. currentScope.handleHasOnePreload(field, currentPreloadConditions)
  45. case "has_many":
  46. currentScope.handleHasManyPreload(field, currentPreloadConditions)
  47. case "belongs_to":
  48. currentScope.handleBelongsToPreload(field, currentPreloadConditions)
  49. case "many_to_many":
  50. currentScope.handleManyToManyPreload(field, currentPreloadConditions)
  51. default:
  52. scope.Err(errors.New("unsupported relation"))
  53. }
  54. preloadedMap[preloadKey] = true
  55. break
  56. }
  57. if !preloadedMap[preloadKey] {
  58. scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
  59. return
  60. }
  61. }
  62. // preload next level
  63. if idx < len(preloadFields)-1 {
  64. currentScope = currentScope.getColumnAsScope(preloadField)
  65. if currentScope != nil {
  66. currentFields = currentScope.Fields()
  67. }
  68. }
  69. }
  70. }
  71. }
  72. func autoPreload(scope *Scope) {
  73. for _, field := range scope.Fields() {
  74. if field.Relationship == nil {
  75. continue
  76. }
  77. if val, ok := field.TagSettings["PRELOAD"]; ok {
  78. if preload, err := strconv.ParseBool(val); err != nil {
  79. scope.Err(errors.New("invalid preload option"))
  80. return
  81. } else if !preload {
  82. continue
  83. }
  84. }
  85. scope.Search.Preload(field.Name)
  86. }
  87. }
  88. func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
  89. var (
  90. preloadDB = scope.NewDB()
  91. preloadConditions []interface{}
  92. )
  93. for _, condition := range conditions {
  94. if scopes, ok := condition.(func(*DB) *DB); ok {
  95. preloadDB = scopes(preloadDB)
  96. } else {
  97. preloadConditions = append(preloadConditions, condition)
  98. }
  99. }
  100. return preloadDB, preloadConditions
  101. }
  102. // handleHasOnePreload used to preload has one associations
  103. func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
  104. relation := field.Relationship
  105. // get relations's primary keys
  106. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  107. if len(primaryKeys) == 0 {
  108. return
  109. }
  110. // preload conditions
  111. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  112. // find relations
  113. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  114. values := toQueryValues(primaryKeys)
  115. if relation.PolymorphicType != "" {
  116. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  117. values = append(values, relation.PolymorphicValue)
  118. }
  119. results := makeSlice(field.Struct.Type)
  120. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  121. // assign find results
  122. var (
  123. resultsValue = indirect(reflect.ValueOf(results))
  124. indirectScopeValue = scope.IndirectValue()
  125. )
  126. if indirectScopeValue.Kind() == reflect.Slice {
  127. for j := 0; j < indirectScopeValue.Len(); j++ {
  128. for i := 0; i < resultsValue.Len(); i++ {
  129. result := resultsValue.Index(i)
  130. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  131. if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
  132. indirectValue.FieldByName(field.Name).Set(result)
  133. break
  134. }
  135. }
  136. }
  137. } else {
  138. for i := 0; i < resultsValue.Len(); i++ {
  139. result := resultsValue.Index(i)
  140. scope.Err(field.Set(result))
  141. }
  142. }
  143. }
  144. // handleHasManyPreload used to preload has many associations
  145. func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
  146. relation := field.Relationship
  147. // get relations's primary keys
  148. primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
  149. if len(primaryKeys) == 0 {
  150. return
  151. }
  152. // preload conditions
  153. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  154. // find relations
  155. query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
  156. values := toQueryValues(primaryKeys)
  157. if relation.PolymorphicType != "" {
  158. query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
  159. values = append(values, relation.PolymorphicValue)
  160. }
  161. results := makeSlice(field.Struct.Type)
  162. scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
  163. // assign find results
  164. var (
  165. resultsValue = indirect(reflect.ValueOf(results))
  166. indirectScopeValue = scope.IndirectValue()
  167. )
  168. if indirectScopeValue.Kind() == reflect.Slice {
  169. preloadMap := make(map[string][]reflect.Value)
  170. for i := 0; i < resultsValue.Len(); i++ {
  171. result := resultsValue.Index(i)
  172. foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
  173. preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
  174. }
  175. for j := 0; j < indirectScopeValue.Len(); j++ {
  176. object := indirect(indirectScopeValue.Index(j))
  177. objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
  178. f := object.FieldByName(field.Name)
  179. if results, ok := preloadMap[toString(objectRealValue)]; ok {
  180. f.Set(reflect.Append(f, results...))
  181. } else {
  182. f.Set(reflect.MakeSlice(f.Type(), 0, 0))
  183. }
  184. }
  185. } else {
  186. scope.Err(field.Set(resultsValue))
  187. }
  188. }
  189. // handleBelongsToPreload used to preload belongs to associations
  190. func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
  191. relation := field.Relationship
  192. // preload conditions
  193. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  194. // get relations's primary keys
  195. primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
  196. if len(primaryKeys) == 0 {
  197. return
  198. }
  199. // find relations
  200. results := makeSlice(field.Struct.Type)
  201. scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
  202. // assign find results
  203. var (
  204. resultsValue = indirect(reflect.ValueOf(results))
  205. indirectScopeValue = scope.IndirectValue()
  206. )
  207. for i := 0; i < resultsValue.Len(); i++ {
  208. result := resultsValue.Index(i)
  209. if indirectScopeValue.Kind() == reflect.Slice {
  210. value := getValueFromFields(result, relation.AssociationForeignFieldNames)
  211. for j := 0; j < indirectScopeValue.Len(); j++ {
  212. object := indirect(indirectScopeValue.Index(j))
  213. if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
  214. object.FieldByName(field.Name).Set(result)
  215. }
  216. }
  217. } else {
  218. scope.Err(field.Set(result))
  219. }
  220. }
  221. }
  222. // handleManyToManyPreload used to preload many to many associations
  223. func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
  224. var (
  225. relation = field.Relationship
  226. joinTableHandler = relation.JoinTableHandler
  227. fieldType = field.Struct.Type.Elem()
  228. foreignKeyValue interface{}
  229. foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
  230. linkHash = map[string][]reflect.Value{}
  231. isPtr bool
  232. )
  233. if fieldType.Kind() == reflect.Ptr {
  234. isPtr = true
  235. fieldType = fieldType.Elem()
  236. }
  237. var sourceKeys = []string{}
  238. for _, key := range joinTableHandler.SourceForeignKeys() {
  239. sourceKeys = append(sourceKeys, key.DBName)
  240. }
  241. // preload conditions
  242. preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
  243. // generate query with join table
  244. newScope := scope.New(reflect.New(fieldType).Interface())
  245. preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
  246. if len(preloadDB.search.selects) == 0 {
  247. preloadDB = preloadDB.Select("*")
  248. }
  249. preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
  250. // preload inline conditions
  251. if len(preloadConditions) > 0 {
  252. preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
  253. }
  254. rows, err := preloadDB.Rows()
  255. if scope.Err(err) != nil {
  256. return
  257. }
  258. defer rows.Close()
  259. columns, _ := rows.Columns()
  260. for rows.Next() {
  261. var (
  262. elem = reflect.New(fieldType).Elem()
  263. fields = scope.New(elem.Addr().Interface()).Fields()
  264. )
  265. // register foreign keys in join tables
  266. var joinTableFields []*Field
  267. for _, sourceKey := range sourceKeys {
  268. joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
  269. }
  270. scope.scan(rows, columns, append(fields, joinTableFields...))
  271. var foreignKeys = make([]interface{}, len(sourceKeys))
  272. // generate hashed forkey keys in join table
  273. for idx, joinTableField := range joinTableFields {
  274. if !joinTableField.Field.IsNil() {
  275. foreignKeys[idx] = joinTableField.Field.Elem().Interface()
  276. }
  277. }
  278. hashedSourceKeys := toString(foreignKeys)
  279. if isPtr {
  280. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
  281. } else {
  282. linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
  283. }
  284. }
  285. if err := rows.Err(); err != nil {
  286. scope.Err(err)
  287. }
  288. // assign find results
  289. var (
  290. indirectScopeValue = scope.IndirectValue()
  291. fieldsSourceMap = map[string][]reflect.Value{}
  292. foreignFieldNames = []string{}
  293. )
  294. for _, dbName := range relation.ForeignFieldNames {
  295. if field, ok := scope.FieldByName(dbName); ok {
  296. foreignFieldNames = append(foreignFieldNames, field.Name)
  297. }
  298. }
  299. if indirectScopeValue.Kind() == reflect.Slice {
  300. for j := 0; j < indirectScopeValue.Len(); j++ {
  301. object := indirect(indirectScopeValue.Index(j))
  302. key := toString(getValueFromFields(object, foreignFieldNames))
  303. fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
  304. }
  305. } else if indirectScopeValue.IsValid() {
  306. key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
  307. fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
  308. }
  309. for source, link := range linkHash {
  310. for i, field := range fieldsSourceMap[source] {
  311. //If not 0 this means Value is a pointer and we already added preloaded models to it
  312. if fieldsSourceMap[source][i].Len() != 0 {
  313. continue
  314. }
  315. field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
  316. }
  317. }
  318. }