callback_update.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package gorm
import (
	"errors"
	"fmt"
	"sort"
	"strings"
)
  7. // Define callbacks for updating
  8. func init() {
  9. DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
  10. DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
  11. DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
  12. DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
  13. DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
  14. DefaultCallback.Update().Register("gorm:update", updateCallback)
  15. DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
  16. DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
  17. DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
  18. }
  19. // assignUpdatingAttributesCallback assign updating attributes to model
  20. func assignUpdatingAttributesCallback(scope *Scope) {
  21. if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
  22. if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
  23. scope.InstanceSet("gorm:update_attrs", updateMaps)
  24. } else {
  25. scope.SkipLeft()
  26. }
  27. }
  28. }
  29. // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
  30. func beforeUpdateCallback(scope *Scope) {
  31. if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
  32. scope.Err(errors.New("Missing WHERE clause while updating"))
  33. return
  34. }
  35. if _, ok := scope.Get("gorm:update_column"); !ok {
  36. if !scope.HasError() {
  37. scope.CallMethod("BeforeSave")
  38. }
  39. if !scope.HasError() {
  40. scope.CallMethod("BeforeUpdate")
  41. }
  42. }
  43. }
  44. // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
  45. func updateTimeStampForUpdateCallback(scope *Scope) {
  46. if _, ok := scope.Get("gorm:update_column"); !ok {
  47. scope.SetColumn("UpdatedAt", NowFunc())
  48. }
  49. }
  50. // updateCallback the callback used to update data to database
  51. func updateCallback(scope *Scope) {
  52. if !scope.HasError() {
  53. var sqls []string
  54. if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
  55. for column, value := range updateAttrs.(map[string]interface{}) {
  56. sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
  57. }
  58. } else {
  59. for _, field := range scope.Fields() {
  60. if scope.changeableField(field) {
  61. if !field.IsPrimaryKey && field.IsNormal {
  62. sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
  63. } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
  64. for _, foreignKey := range relationship.ForeignDBNames {
  65. if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
  66. sqls = append(sqls,
  67. fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
  68. }
  69. }
  70. }
  71. }
  72. }
  73. }
  74. var extraOption string
  75. if str, ok := scope.Get("gorm:update_option"); ok {
  76. extraOption = fmt.Sprint(str)
  77. }
  78. if len(sqls) > 0 {
  79. scope.Raw(fmt.Sprintf(
  80. "UPDATE %v SET %v%v%v",
  81. scope.QuotedTableName(),
  82. strings.Join(sqls, ", "),
  83. addExtraSpaceIfExist(scope.CombinedConditionSql()),
  84. addExtraSpaceIfExist(extraOption),
  85. )).Exec()
  86. }
  87. }
  88. }
  89. // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
  90. func afterUpdateCallback(scope *Scope) {
  91. if _, ok := scope.Get("gorm:update_column"); !ok {
  92. if !scope.HasError() {
  93. scope.CallMethod("AfterUpdate")
  94. }
  95. if !scope.HasError() {
  96. scope.CallMethod("AfterSave")
  97. }
  98. }
  99. }