task.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. package mysql
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "go-common/app/admin/main/aegis/model/common"
  9. modtask "go-common/app/admin/main/aegis/model/task"
  10. xsql "go-common/library/database/sql"
  11. "go-common/library/log"
  12. "go-common/library/xstr"
  13. "github.com/pkg/errors"
  14. )
// SQL statements used by the task DAO. Statements containing %s are templates
// that are completed with fmt.Sprintf before being executed.
const (
	// _taskSQL loads one task row by primary key.
	_taskSQL = "SELECT id,business_id,flow_id,rid,admin_id,uid,state,weight,utime,gtime,mid,fans,`group`,reason,ctime,mtime from task WHERE id=?"
	// _listCheckSQL selects the ids that exist among a given id list; callers append extra AND conditions.
	_listCheckSQL = "SELECT id FROM task WHERE id IN (%s)"
	// _dispatchByIDSQL stamps gtime on a single task matched by id, state and uid, only while its gtime is still 0.
	_dispatchByIDSQL = "UPDATE task SET gtime=? WHERE id=? AND state=? AND uid=? AND gtime=0"
	// _queryGtimeSQL reads the gtime of a single task matched by id, state and uid.
	_queryGtimeSQL = "SELECT gtime FROM task WHERE id=? AND state=? AND uid=?"
	// _dispatchSQL stamps gtime on up to LIMIT tasks of a uid in the given state whose gtime is still the zero datetime, taken in ascending weight order.
	_dispatchSQL = "UPDATE task SET gtime=? WHERE state=? AND uid=? AND gtime='0000-00-00 00:00:00' ORDER BY weight LIMIT ?"
	// _releaseSQL clears admin_id, uid, state and gtime on a uid's tasks of one business/flow that are in the given state, or in state 0 with admin_id>0.
	_releaseSQL = "UPDATE task SET admin_id=0,uid=0,state=0,gtime='0000-00-00 00:00:00' WHERE business_id=? AND flow_id=? AND uid=? AND (state=? OR (state=0 AND admin_id>0))"
	// _resetGtimeSQL sets gtime back to the zero datetime on a uid's tasks of one business/flow in the given state.
	_resetGtimeSQL = "UPDATE task SET gtime='0000-00-00 00:00:00' WHERE state=? AND business_id=? AND flow_id=? AND uid=?"
	// _seizeSQL moves a task to a new state and uid, provided it is currently in the expected state.
	_seizeSQL = "UPDATE task SET state=?,uid=? WHERE id=? AND state=?"
	// _submitSQL sets state, uid and utime on a task matched by id, current state and current uid.
	_submitSQL = "UPDATE task SET state=?,uid=?,utime=? WHERE id=? AND state=? AND uid=?"
	// _delaySQL sets state, uid and reason and zeroes gtime on a task matched by id, current state and current uid.
	_delaySQL = "UPDATE task SET state=?,uid=?,reason=?,gtime='0000-00-00 00:00:00' WHERE id=? AND state=? AND uid=?"
	// _consumerSQL upserts the state of a (business_id, flow_id, uid) consumer row.
	_consumerSQL = "INSERT INTO task_consumer (business_id,flow_id,uid,state) VALUES (?,?,?,?) ON DUPLICATE KEY UPDATE state=?"
	// _onlinesSQL lists uid and mtime of the consumers of one business/flow that are in the given state.
	_onlinesSQL = "SELECT uid,mtime FROM task_consumer WHERE business_id=? AND flow_id=? AND state=?"
	// _isconsumerOnSQL reads the state of a single consumer row.
	_isconsumerOnSQL = "SELECT state FROM task_consumer WHERE business_id=? AND flow_id=? AND uid=?"
	// _queryTaskSQL pages through tasks in a given state with mtime<=? and id greater than a cursor, ordered by id.
	_queryTaskSQL = "SELECT id,business_id,flow_id,uid,weight FROM task WHERE state=? AND mtime<=? AND id>? ORDER BY id LIMIT ?"
	// _countPersonalSQL counts a uid's tasks of one business/flow in the given state.
	_countPersonalSQL = "SELECT count(*) FROM task WHERE state=? AND business_id=? AND flow_id=? AND uid=?"
	// _queryForSeizeSQL selects up to LIMIT task ids of one business/flow in the given state whose uid is 0 or the given uid, highest weight first.
	_queryForSeizeSQL = "SELECT id FROM task WHERE state=? AND business_id=? AND flow_id=? AND uid IN (0,?) ORDER BY weight DESC LIMIT ?"
	// _listTasksSQL lists full task rows for a caller-supplied WHERE clause (%s), highest weight first, with LIMIT offset,row_count paging.
	_listTasksSQL = "SELECT `id`,`business_id`,`flow_id`,`rid`,`admin_id`,`uid`,`state`,`weight`,`utime`,`gtime`,`mid`,`fans`,`group`,`reason`,`ctime`,`mtime` FROM task %s ORDER BY weight DESC LIMIT ?,?"
)
  34. // TaskFromDB .
  35. func (d *Dao) TaskFromDB(c context.Context, id int64) (task *modtask.Task, err error) {
  36. task = &modtask.Task{}
  37. err = d.db.QueryRow(c, _taskSQL, id).
  38. Scan(&task.ID, &task.BusinessID, &task.FlowID, &task.RID, &task.AdminID, &task.UID, &task.State,
  39. &task.Weight, &task.Utime, &task.Gtime, &task.MID, &task.Fans, &task.Group, &task.Reason, &task.Ctime, &task.Mtime)
  40. if err != nil {
  41. task = nil
  42. if err == sql.ErrNoRows {
  43. log.Error("TaskFromDB(%d) norows", id)
  44. err = nil
  45. return
  46. }
  47. log.Error("TaskFromDB(%d) error(%v)", id, errors.WithStack(err))
  48. }
  49. return
  50. }
  51. // DispatchByID 派遣任务,更新gtime
  52. func (d *Dao) DispatchByID(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
  53. var (
  54. gtime = time.Now()
  55. uid = args[0].(int64)
  56. )
  57. missids = make(map[int64]struct{})
  58. for _, id := range ids {
  59. var (
  60. rows int64
  61. gt time.Time
  62. res sql.Result
  63. )
  64. if err = d.db.QueryRow(c, _queryGtimeSQL, id, modtask.TaskStateDispatch, uid).Scan(&gt); err != nil {
  65. if err == sql.ErrNoRows {
  66. missids[id] = struct{}{}
  67. err = nil
  68. continue
  69. }
  70. log.Error("d.db.QueryRow error(%v)", errors.WithStack(err))
  71. return
  72. }
  73. if gt.IsZero() {
  74. res, err = d.db.Exec(c, _dispatchByIDSQL, gtime, id, modtask.TaskStateDispatch, uid)
  75. if err != nil {
  76. log.Error("Exec error(%v)", errors.WithStack(err))
  77. return
  78. }
  79. if rows, err = res.RowsAffected(); err != nil {
  80. log.Error("RowsAffected error(%v)", errors.WithStack(err))
  81. return
  82. }
  83. if rows == 0 {
  84. missids[id] = struct{}{}
  85. } else {
  86. mtasks[id].Gtime = common.IntTime(gtime.Unix())
  87. }
  88. } else {
  89. mtasks[id].Gtime = common.IntTime(gt.Unix())
  90. }
  91. }
  92. return
  93. }
  94. // DBDispatch 直接数据库派遣
  95. func (d *Dao) DBDispatch(c context.Context, opt *modtask.NextOptions) (tasks []*modtask.Task, count int64, err error) {
  96. var (
  97. res sql.Result
  98. gtime = time.Now()
  99. )
  100. // 1.直接更新派遣时间
  101. res, err = d.db.Exec(c, _dispatchSQL, gtime, modtask.TaskStateDispatch, opt.UID, opt.DispatchCount)
  102. if err != nil {
  103. log.Error("Exec error(%v)", errors.WithStack(err))
  104. return
  105. }
  106. if count, err = res.RowsAffected(); err != nil {
  107. log.Error("RowsAffected error(%v)", errors.WithStack(err))
  108. return
  109. }
  110. // 2.读取任务
  111. wherecache := fmt.Sprintf("WHERE state=%d AND uid=%d AND gtime!='0000-00-00 00:00:00'", modtask.TaskStateDispatch, opt.UID)
  112. return d.listTasks(c, &modtask.ListOptions{BaseOptions: opt.BaseOptions, Pager: common.Pager{Pn: 1, Ps: int(opt.DispatchCount)}}, wherecache)
  113. }
  114. // Release 释放任务
  115. func (d *Dao) Release(c context.Context, opt *common.BaseOptions, delay bool) (rows int64, err error) {
  116. sql := _releaseSQL
  117. if delay {
  118. sql = _releaseSQL + " AND gtime='0000-00-00 00:00:00'"
  119. }
  120. log.Info("Mysql Release(%+v) delay(%v)", opt, delay)
  121. res, err := d.db.Exec(c, sql, opt.BusinessID, opt.FlowID, opt.UID, modtask.TaskStateDispatch)
  122. if err != nil {
  123. log.Error("db.Exec(%s)[%d,%d,%d,%d] error(%v)", sql, opt.BusinessID, opt.FlowID, opt.UID, modtask.TaskStateDispatch, err)
  124. return
  125. }
  126. // 已经下发的延迟5分钟释放
  127. if delay {
  128. _, err = d.db.Exec(c, _resetGtimeSQL, modtask.TaskStateDispatch, opt.BusinessID, opt.FlowID, opt.UID)
  129. if err != nil {
  130. log.Error("db.Exec(%s)[%d,%d,%d,%d] error(%v)", sql, modtask.TaskStateDispatch, opt.BusinessID, opt.FlowID, opt.UID, err)
  131. }
  132. time.AfterFunc(5*time.Minute, func() {
  133. d.Release(context.Background(), opt, false)
  134. })
  135. }
  136. return res.RowsAffected()
  137. }
  138. // Seize 抢占任务
  139. func (d *Dao) Seize(c context.Context, mapids map[int64]int64) (count int64, err error) {
  140. tx, err := d.db.Begin(c)
  141. if err != nil {
  142. log.Error("d.Seize.Begin error(%v)", errors.WithStack(err))
  143. return
  144. }
  145. defer tx.Commit()
  146. for tid, uid := range mapids {
  147. var (
  148. rows int64
  149. res sql.Result
  150. )
  151. res, err = tx.Exec(_seizeSQL, modtask.TaskStateDispatch, uid, tid, modtask.TaskStateInit)
  152. if err != nil {
  153. log.Error("Exec error(%v)", errors.WithStack(err))
  154. tx.Rollback()
  155. return
  156. }
  157. if rows, err = res.RowsAffected(); err != nil {
  158. log.Error("RowsAffected error(%v)", errors.WithStack(err))
  159. tx.Rollback()
  160. return
  161. }
  162. if rows == 1 {
  163. count++
  164. }
  165. }
  166. return
  167. }
  168. // Delay 延迟任务
  169. func (d *Dao) Delay(c context.Context, opt *modtask.DelayOptions) (rows int64, err error) {
  170. var (
  171. res sql.Result
  172. )
  173. res, err = d.db.Exec(c, _delaySQL, modtask.TaskStateDelay, opt.UID, opt.Reason, opt.TaskID, modtask.TaskStateDispatch, opt.UID)
  174. if err != nil {
  175. log.Error("Exec error(%v)", errors.WithStack(err))
  176. return
  177. }
  178. if rows, err = res.RowsAffected(); err != nil {
  179. log.Error("RowsAffected error(%v)", errors.WithStack(err))
  180. return
  181. }
  182. return
  183. }
  184. // ListCheckUnSeized .
  185. func (d *Dao) ListCheckUnSeized(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
  186. wherecase := fmt.Sprintf("state = %d", modtask.TaskStateInit)
  187. return d.listCheck(c, wherecase, ids)
  188. }
  189. // ListCheckSeized .
  190. func (d *Dao) ListCheckSeized(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
  191. if len(args) < 1 {
  192. return
  193. }
  194. uid := args[0].(int64)
  195. wherecase := fmt.Sprintf("state = %d", modtask.TaskStateDispatch)
  196. if uid != 0 {
  197. wherecase += fmt.Sprintf(" AND uid=%d", uid)
  198. }
  199. return d.listCheck(c, wherecase, ids)
  200. }
  201. // ListCheckDelay .
  202. func (d *Dao) ListCheckDelay(c context.Context, mtasks map[int64]*modtask.Task, ids []int64, args ...interface{}) (missids map[int64]struct{}, err error) {
  203. if len(args) < 1 {
  204. return
  205. }
  206. uid := args[0].(int64)
  207. wherecase := fmt.Sprintf("state=%d", modtask.TaskStateDelay)
  208. if uid != 0 {
  209. wherecase += fmt.Sprintf(" AND uid=%d", uid)
  210. }
  211. return d.listCheck(c, wherecase, ids)
  212. }
  213. // ListTasks .
  214. func (d *Dao) ListTasks(c context.Context, opt *modtask.ListOptions) (tasks []*modtask.Task, count int64, err error) {
  215. var (
  216. wherecase string
  217. cases []string
  218. state int8
  219. isDefault bool
  220. )
  221. switch opt.State {
  222. case 1:
  223. state = modtask.TaskStateInit
  224. case 2:
  225. state = modtask.TaskStateDispatch
  226. case 3:
  227. state = modtask.TaskStateDelay
  228. case 4:
  229. state = modtask.TaskStateDispatch
  230. cases = append(cases, "admin_id>0")
  231. default:
  232. isDefault = true
  233. cases = append(cases, fmt.Sprintf("state<%d", modtask.TaskStateSubmit))
  234. }
  235. if !isDefault {
  236. cases = append(cases, fmt.Sprintf("state=%d", state))
  237. if !opt.BisLeader && (opt.State == 2 || opt.State == 3 || opt.State == 4) {
  238. cases = append(cases, fmt.Sprintf("uid=%d", opt.UID))
  239. }
  240. }
  241. wherecase = fmt.Sprintf("WHERE business_id=%d AND flow_id=%d AND ", opt.BusinessID, opt.FlowID) + strings.Join(cases, " AND ")
  242. return d.listTasks(c, opt, wherecase)
  243. }
  244. func (d *Dao) listTasks(c context.Context, opt *modtask.ListOptions, wherecase string) (tasks []*modtask.Task, count int64, err error) {
  245. countSQL := fmt.Sprintf("SELECT count(*) FROM task %s", wherecase)
  246. if err = d.db.QueryRow(c, countSQL).Scan(&count); err != nil {
  247. log.Error("QueryRow error(%v)", err)
  248. return
  249. }
  250. if count > 0 {
  251. var (
  252. rows *xsql.Rows
  253. listSQL = fmt.Sprintf(_listTasksSQL, wherecase)
  254. )
  255. if rows, err = d.db.Query(c, listSQL, (opt.Pn-1)*opt.Ps, opt.Pn*opt.Ps); err != nil {
  256. log.Error("Query error(%v)", err)
  257. return
  258. }
  259. defer rows.Close()
  260. for rows.Next() {
  261. task := &modtask.Task{}
  262. if err = rows.Scan(&task.ID, &task.BusinessID, &task.FlowID, &task.RID, &task.AdminID, &task.UID, &task.State,
  263. &task.Weight, &task.Utime, &task.Gtime, &task.MID, &task.Fans, &task.Group, &task.Reason, &task.Ctime, &task.Mtime); err != nil {
  264. log.Error("Scan error(%v)", err)
  265. return
  266. }
  267. tasks = append(tasks, task)
  268. }
  269. }
  270. return
  271. }
  272. func (d *Dao) listCheck(c context.Context, wherecase string, ids []int64) (missids map[int64]struct{}, err error) {
  273. if len(ids) == 0 {
  274. return
  275. }
  276. missids = make(map[int64]struct{})
  277. mapids := make(map[int64]struct{})
  278. log.Info("listCheck ids(%v)", ids)
  279. defer func() {
  280. log.Info("listCheck missids(%v)", missids)
  281. }()
  282. for _, id := range ids {
  283. mapids[id] = struct{}{}
  284. }
  285. var (
  286. rows *xsql.Rows
  287. sqlstring = fmt.Sprintf(_listCheckSQL, xstr.JoinInts(ids)) + " AND " + wherecase
  288. )
  289. if rows, err = d.db.Query(c, sqlstring); err != nil {
  290. log.Error("db.Query(%s) error(%v)", sqlstring, errors.WithStack(err))
  291. return
  292. }
  293. defer rows.Close()
  294. for rows.Next() {
  295. var id int64
  296. if err = rows.Scan(&id); err != nil {
  297. log.Error("rows.Scan error(%v)", errors.WithStack(err))
  298. return
  299. }
  300. delete(mapids, id)
  301. }
  302. for id := range mapids {
  303. missids[id] = struct{}{}
  304. }
  305. return
  306. }
  307. // ConsumerOn .
  308. func (d *Dao) ConsumerOn(c context.Context, opt *common.BaseOptions) (err error) {
  309. return d.consumer(c, opt, modtask.ActionConsumerOn)
  310. }
  311. // ConsumerOff .
  312. func (d *Dao) ConsumerOff(c context.Context, opt *common.BaseOptions) (err error) {
  313. return d.consumer(c, opt, modtask.ActionConsumerOff)
  314. }
  315. // IsConsumerOn .
  316. func (d *Dao) IsConsumerOn(c context.Context, opt *common.BaseOptions) (on bool, err error) {
  317. var state int8
  318. if err = d.db.QueryRow(c, _isconsumerOnSQL, opt.BusinessID, opt.FlowID, opt.UID).Scan(&state); err != nil {
  319. if err == sql.ErrNoRows {
  320. err = nil
  321. return
  322. }
  323. log.Error("d.db.QueryRow error(%v)", err)
  324. return
  325. }
  326. if state == modtask.ActionConsumerOn {
  327. on = true
  328. }
  329. return
  330. }
  331. func (d *Dao) consumer(c context.Context, opt *common.BaseOptions, action int8) (err error) {
  332. var (
  333. res sql.Result
  334. )
  335. res, err = d.db.Exec(c, _consumerSQL, opt.BusinessID, opt.FlowID, opt.UID, action, action)
  336. if err != nil {
  337. log.Error("Exec error(%v)", errors.WithStack(err))
  338. return
  339. }
  340. if _, err = res.RowsAffected(); err != nil {
  341. log.Error("RowsAffected error(%v)", errors.WithStack(err))
  342. return
  343. }
  344. return
  345. }
  346. // ConsumerStat 24小时内有活动或者在线的用户
  347. func (d *Dao) ConsumerStat(c context.Context, bizid, flowid int64) (items []*modtask.WatchItem, err error) {
  348. var rows *xsql.Rows
  349. sql := "SELECT uid,mtime,state from task_consumer where business_id=? AND flow_id=? AND (mtime > ? or state=1) order by mtime desc"
  350. if rows, err = d.db.Query(c, sql, bizid, flowid, time.Now().Add(-24*time.Hour)); err != nil {
  351. log.Error("ConsumerStat error(%v)", err)
  352. return
  353. }
  354. defer rows.Close()
  355. for rows.Next() {
  356. item := &modtask.WatchItem{}
  357. if err = rows.Scan(&item.UID, &item.Mtime, &item.State); err != nil {
  358. log.Error("ConsumerStat error(%v)", err)
  359. return
  360. }
  361. items = append(items, item)
  362. }
  363. return
  364. }
  365. // Onlines 在线列表
  366. func (d *Dao) Onlines(c context.Context, opt *common.BaseOptions) (uids map[int64]time.Time, err error) {
  367. var (
  368. rows *xsql.Rows
  369. )
  370. rows, err = d.db.Query(c, _onlinesSQL, opt.BusinessID, opt.FlowID, modtask.ActionConsumerOn)
  371. if err != nil {
  372. log.Error("db.Query error(%v)", err)
  373. return
  374. }
  375. defer rows.Close()
  376. uids = make(map[int64]time.Time)
  377. for rows.Next() {
  378. var (
  379. uid int64
  380. mtime time.Time
  381. )
  382. if err = rows.Scan(&uid, &mtime); err != nil {
  383. log.Error("rows.Scan error(%v)", err)
  384. return
  385. }
  386. uids[uid] = mtime
  387. }
  388. return
  389. }
  390. // QueryTask .
  391. func (d *Dao) QueryTask(c context.Context, state int8, mtime time.Time, id, limit int64) (tasks []*modtask.Task, lastid int64, err error) {
  392. var rows *xsql.Rows
  393. rows, err = d.db.Query(c, _queryTaskSQL, state, mtime, id, limit)
  394. if err != nil {
  395. log.Error("db.Query error(%v)", err)
  396. return
  397. }
  398. defer rows.Close()
  399. for rows.Next() {
  400. task := &modtask.Task{}
  401. if err = rows.Scan(&task.ID, &task.BusinessID, &task.FlowID, &task.UID, &task.Weight); err != nil {
  402. log.Error("rows.Scan error(%v)", err)
  403. return
  404. }
  405. tasks = append(tasks, task)
  406. lastid = task.ID
  407. }
  408. return
  409. }
  410. // CountPersonal count personal task
  411. func (d *Dao) CountPersonal(c context.Context, opt *common.BaseOptions) (count int64, err error) {
  412. if err = d.db.QueryRow(c, _countPersonalSQL, modtask.TaskStateDispatch, opt.BusinessID, opt.FlowID, opt.UID).Scan(&count); err != nil {
  413. log.Error("QueryRow error(%v)", errors.WithStack(err))
  414. return
  415. }
  416. return
  417. }
  418. // QueryForSeize 查询当前可抢占的任务
  419. func (d *Dao) QueryForSeize(c context.Context, businessID, flowID, uid, seizecount int64) (hitids []int64, err error) {
  420. log.Info("task-QueryForSeize businessID(%d), flowID(%d), uid(%d), seizecount(%d)", businessID, flowID, uid, seizecount)
  421. defer func() { log.Info("task-QueryForSeize hitids(%v), err(%v)", hitids, err) }()
  422. var rows *xsql.Rows
  423. rows, err = d.db.Query(c, _queryForSeizeSQL, modtask.TaskStateInit, businessID, flowID, uid, seizecount)
  424. if err != nil {
  425. log.Error("db.Query error(%v)", err)
  426. return
  427. }
  428. defer rows.Close()
  429. for rows.Next() {
  430. var id int64
  431. if err = rows.Scan(&id); err != nil {
  432. log.Error("rows.Scan error(%v)", err)
  433. return
  434. }
  435. hitids = append(hitids, id)
  436. }
  437. return
  438. }