pipeline.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package pipeline
  2. import (
  3. "context"
  4. "errors"
  5. "sync"
  6. "time"
  7. "go-common/library/net/metadata"
  8. xtime "go-common/library/time"
  9. )
  10. // ErrFull channel full error
  11. var ErrFull = errors.New("channel full")
  12. type message struct {
  13. key string
  14. value interface{}
  15. }
  16. // Pipeline pipeline struct
  17. type Pipeline struct {
  18. Do func(c context.Context, index int, values map[string][]interface{})
  19. Split func(key string) int
  20. chans []chan *message
  21. mirrorChans []chan *message
  22. config *Config
  23. wait sync.WaitGroup
  24. }
  25. // Config Pipeline config
  26. type Config struct {
  27. // MaxSize merge size
  28. MaxSize int
  29. // Interval merge interval
  30. Interval xtime.Duration
  31. // Buffer channel size
  32. Buffer int
  33. // Worker channel number
  34. Worker int
  35. // Smooth smoothing interval
  36. Smooth bool
  37. }
  38. func (c *Config) fix() {
  39. if c.MaxSize <= 0 {
  40. c.MaxSize = 1000
  41. }
  42. if c.Interval <= 0 {
  43. c.Interval = xtime.Duration(time.Second)
  44. }
  45. if c.Buffer <= 0 {
  46. c.Buffer = 1000
  47. }
  48. if c.Worker <= 0 {
  49. c.Worker = 10
  50. }
  51. }
  52. // NewPipeline new pipline
  53. func NewPipeline(config *Config) (res *Pipeline) {
  54. if config == nil {
  55. config = &Config{}
  56. }
  57. config.fix()
  58. res = &Pipeline{
  59. chans: make([]chan *message, config.Worker),
  60. mirrorChans: make([]chan *message, config.Worker),
  61. config: config,
  62. }
  63. for i := 0; i < config.Worker; i++ {
  64. res.chans[i] = make(chan *message, config.Buffer)
  65. res.mirrorChans[i] = make(chan *message, config.Buffer)
  66. }
  67. return
  68. }
  69. // Start start all mergeproc
  70. func (p *Pipeline) Start() {
  71. if p.Do == nil {
  72. panic("pipeline: do func is nil")
  73. }
  74. if p.Split == nil {
  75. panic("pipeline: split func is nil")
  76. }
  77. var mirror bool
  78. p.wait.Add(len(p.chans) + len(p.mirrorChans))
  79. for i, ch := range p.chans {
  80. go p.mergeproc(mirror, i, ch)
  81. }
  82. mirror = true
  83. for i, ch := range p.mirrorChans {
  84. go p.mergeproc(mirror, i, ch)
  85. }
  86. }
  87. // SyncAdd sync add a value to channal, channel shard in split method
  88. func (p *Pipeline) SyncAdd(c context.Context, key string, value interface{}) {
  89. ch, msg := p.add(c, key, value)
  90. ch <- msg
  91. }
  92. // Add async add a value to channal, channel shard in split method
  93. func (p *Pipeline) Add(c context.Context, key string, value interface{}) (err error) {
  94. ch, msg := p.add(c, key, value)
  95. select {
  96. case ch <- msg:
  97. default:
  98. err = ErrFull
  99. }
  100. return
  101. }
  102. func (p *Pipeline) add(c context.Context, key string, value interface{}) (ch chan *message, m *message) {
  103. shard := p.Split(key) % p.config.Worker
  104. if metadata.Bool(c, metadata.Mirror) {
  105. ch = p.mirrorChans[shard]
  106. } else {
  107. ch = p.chans[shard]
  108. }
  109. m = &message{key: key, value: value}
  110. return
  111. }
  112. // Close all goroutinue
  113. func (p *Pipeline) Close() (err error) {
  114. for _, ch := range p.chans {
  115. ch <- nil
  116. }
  117. for _, ch := range p.mirrorChans {
  118. ch <- nil
  119. }
  120. p.wait.Wait()
  121. return
  122. }
  123. func (p *Pipeline) mergeproc(mirror bool, index int, ch <-chan *message) {
  124. defer p.wait.Done()
  125. var (
  126. m *message
  127. vals = make(map[string][]interface{}, p.config.MaxSize)
  128. closed bool
  129. count int
  130. inteval = p.config.Interval
  131. oldTicker = true
  132. )
  133. if p.config.Smooth && index > 0 {
  134. inteval = xtime.Duration(int64(index) * (int64(p.config.Interval) / int64(p.config.Worker)))
  135. }
  136. ticker := time.NewTicker(time.Duration(inteval))
  137. for {
  138. select {
  139. case m = <-ch:
  140. if m == nil {
  141. closed = true
  142. break
  143. }
  144. count++
  145. vals[m.key] = append(vals[m.key], m.value)
  146. if count >= p.config.MaxSize {
  147. break
  148. }
  149. continue
  150. case <-ticker.C:
  151. if p.config.Smooth && oldTicker {
  152. ticker.Stop()
  153. ticker = time.NewTicker(time.Duration(p.config.Interval))
  154. oldTicker = false
  155. }
  156. }
  157. if len(vals) > 0 {
  158. ctx := context.Background()
  159. if mirror {
  160. ctx = metadata.NewContext(ctx, metadata.MD{metadata.Mirror: true})
  161. }
  162. p.Do(ctx, index, vals)
  163. vals = make(map[string][]interface{}, p.config.MaxSize)
  164. count = 0
  165. }
  166. if closed {
  167. ticker.Stop()
  168. return
  169. }
  170. }
  171. }