123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- package gorm
- import (
- "errors"
- "fmt"
- "reflect"
- "strconv"
- "strings"
- )
- // preloadCallback used to preload associations
- func preloadCallback(scope *Scope) {
- if _, ok := scope.Get("gorm:auto_preload"); ok {
- autoPreload(scope)
- }
- if scope.Search.preload == nil || scope.HasError() {
- return
- }
- var (
- preloadedMap = map[string]bool{}
- fields = scope.Fields()
- )
- for _, preload := range scope.Search.preload {
- var (
- preloadFields = strings.Split(preload.schema, ".")
- currentScope = scope
- currentFields = fields
- )
- for idx, preloadField := range preloadFields {
- var currentPreloadConditions []interface{}
- if currentScope == nil {
- continue
- }
- // if not preloaded
- if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
- // assign search conditions to last preload
- if idx == len(preloadFields)-1 {
- currentPreloadConditions = preload.conditions
- }
- for _, field := range currentFields {
- if field.Name != preloadField || field.Relationship == nil {
- continue
- }
- switch field.Relationship.Kind {
- case "has_one":
- currentScope.handleHasOnePreload(field, currentPreloadConditions)
- case "has_many":
- currentScope.handleHasManyPreload(field, currentPreloadConditions)
- case "belongs_to":
- currentScope.handleBelongsToPreload(field, currentPreloadConditions)
- case "many_to_many":
- currentScope.handleManyToManyPreload(field, currentPreloadConditions)
- default:
- scope.Err(errors.New("unsupported relation"))
- }
- preloadedMap[preloadKey] = true
- break
- }
- if !preloadedMap[preloadKey] {
- scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
- return
- }
- }
- // preload next level
- if idx < len(preloadFields)-1 {
- currentScope = currentScope.getColumnAsScope(preloadField)
- if currentScope != nil {
- currentFields = currentScope.Fields()
- }
- }
- }
- }
- }
- func autoPreload(scope *Scope) {
- for _, field := range scope.Fields() {
- if field.Relationship == nil {
- continue
- }
- if val, ok := field.TagSettings["PRELOAD"]; ok {
- if preload, err := strconv.ParseBool(val); err != nil {
- scope.Err(errors.New("invalid preload option"))
- return
- } else if !preload {
- continue
- }
- }
- scope.Search.Preload(field.Name)
- }
- }
- func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
- var (
- preloadDB = scope.NewDB()
- preloadConditions []interface{}
- )
- for _, condition := range conditions {
- if scopes, ok := condition.(func(*DB) *DB); ok {
- preloadDB = scopes(preloadDB)
- } else {
- preloadConditions = append(preloadConditions, condition)
- }
- }
- return preloadDB, preloadConditions
- }
- // handleHasOnePreload used to preload has one associations
- func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
- relation := field.Relationship
- // get relations's primary keys
- primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
- if len(primaryKeys) == 0 {
- return
- }
- // preload conditions
- preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
- // find relations
- query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
- values := toQueryValues(primaryKeys)
- if relation.PolymorphicType != "" {
- query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
- values = append(values, relation.PolymorphicValue)
- }
- results := makeSlice(field.Struct.Type)
- scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
- // assign find results
- var (
- resultsValue = indirect(reflect.ValueOf(results))
- indirectScopeValue = scope.IndirectValue()
- )
- if indirectScopeValue.Kind() == reflect.Slice {
- for j := 0; j < indirectScopeValue.Len(); j++ {
- for i := 0; i < resultsValue.Len(); i++ {
- result := resultsValue.Index(i)
- foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
- if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
- indirectValue.FieldByName(field.Name).Set(result)
- break
- }
- }
- }
- } else {
- for i := 0; i < resultsValue.Len(); i++ {
- result := resultsValue.Index(i)
- scope.Err(field.Set(result))
- }
- }
- }
- // handleHasManyPreload used to preload has many associations
- func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
- relation := field.Relationship
- // get relations's primary keys
- primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
- if len(primaryKeys) == 0 {
- return
- }
- // preload conditions
- preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
- // find relations
- query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
- values := toQueryValues(primaryKeys)
- if relation.PolymorphicType != "" {
- query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
- values = append(values, relation.PolymorphicValue)
- }
- results := makeSlice(field.Struct.Type)
- scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
- // assign find results
- var (
- resultsValue = indirect(reflect.ValueOf(results))
- indirectScopeValue = scope.IndirectValue()
- )
- if indirectScopeValue.Kind() == reflect.Slice {
- preloadMap := make(map[string][]reflect.Value)
- for i := 0; i < resultsValue.Len(); i++ {
- result := resultsValue.Index(i)
- foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
- preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
- }
- for j := 0; j < indirectScopeValue.Len(); j++ {
- object := indirect(indirectScopeValue.Index(j))
- objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
- f := object.FieldByName(field.Name)
- if results, ok := preloadMap[toString(objectRealValue)]; ok {
- f.Set(reflect.Append(f, results...))
- } else {
- f.Set(reflect.MakeSlice(f.Type(), 0, 0))
- }
- }
- } else {
- scope.Err(field.Set(resultsValue))
- }
- }
- // handleBelongsToPreload used to preload belongs to associations
- func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
- relation := field.Relationship
- // preload conditions
- preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
- // get relations's primary keys
- primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
- if len(primaryKeys) == 0 {
- return
- }
- // find relations
- results := makeSlice(field.Struct.Type)
- scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
- // assign find results
- var (
- resultsValue = indirect(reflect.ValueOf(results))
- indirectScopeValue = scope.IndirectValue()
- )
- for i := 0; i < resultsValue.Len(); i++ {
- result := resultsValue.Index(i)
- if indirectScopeValue.Kind() == reflect.Slice {
- value := getValueFromFields(result, relation.AssociationForeignFieldNames)
- for j := 0; j < indirectScopeValue.Len(); j++ {
- object := indirect(indirectScopeValue.Index(j))
- if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
- object.FieldByName(field.Name).Set(result)
- }
- }
- } else {
- scope.Err(field.Set(result))
- }
- }
- }
- // handleManyToManyPreload used to preload many to many associations
- func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
- var (
- relation = field.Relationship
- joinTableHandler = relation.JoinTableHandler
- fieldType = field.Struct.Type.Elem()
- foreignKeyValue interface{}
- foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
- linkHash = map[string][]reflect.Value{}
- isPtr bool
- )
- if fieldType.Kind() == reflect.Ptr {
- isPtr = true
- fieldType = fieldType.Elem()
- }
- var sourceKeys = []string{}
- for _, key := range joinTableHandler.SourceForeignKeys() {
- sourceKeys = append(sourceKeys, key.DBName)
- }
- // preload conditions
- preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
- // generate query with join table
- newScope := scope.New(reflect.New(fieldType).Interface())
- preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
- if len(preloadDB.search.selects) == 0 {
- preloadDB = preloadDB.Select("*")
- }
- preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
- // preload inline conditions
- if len(preloadConditions) > 0 {
- preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
- }
- rows, err := preloadDB.Rows()
- if scope.Err(err) != nil {
- return
- }
- defer rows.Close()
- columns, _ := rows.Columns()
- for rows.Next() {
- var (
- elem = reflect.New(fieldType).Elem()
- fields = scope.New(elem.Addr().Interface()).Fields()
- )
- // register foreign keys in join tables
- var joinTableFields []*Field
- for _, sourceKey := range sourceKeys {
- joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
- }
- scope.scan(rows, columns, append(fields, joinTableFields...))
- var foreignKeys = make([]interface{}, len(sourceKeys))
- // generate hashed forkey keys in join table
- for idx, joinTableField := range joinTableFields {
- if !joinTableField.Field.IsNil() {
- foreignKeys[idx] = joinTableField.Field.Elem().Interface()
- }
- }
- hashedSourceKeys := toString(foreignKeys)
- if isPtr {
- linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
- } else {
- linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
- }
- }
- if err := rows.Err(); err != nil {
- scope.Err(err)
- }
- // assign find results
- var (
- indirectScopeValue = scope.IndirectValue()
- fieldsSourceMap = map[string][]reflect.Value{}
- foreignFieldNames = []string{}
- )
- for _, dbName := range relation.ForeignFieldNames {
- if field, ok := scope.FieldByName(dbName); ok {
- foreignFieldNames = append(foreignFieldNames, field.Name)
- }
- }
- if indirectScopeValue.Kind() == reflect.Slice {
- for j := 0; j < indirectScopeValue.Len(); j++ {
- object := indirect(indirectScopeValue.Index(j))
- key := toString(getValueFromFields(object, foreignFieldNames))
- fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
- }
- } else if indirectScopeValue.IsValid() {
- key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
- fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
- }
- for source, link := range linkHash {
- for i, field := range fieldsSourceMap[source] {
- //If not 0 this means Value is a pointer and we already added preloaded models to it
- if fieldsSourceMap[source][i].Len() != 0 {
- continue
- }
- field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
- }
- }
- }
|