sessions.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. // Copyright 2012 The Gorilla Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package sessions
  5. import (
  6. "encoding/gob"
  7. "fmt"
  8. "net/http"
  9. "time"
  10. "github.com/gorilla/context"
  11. )
  12. // Default flashes key.
  13. const flashesKey = "_flash"
  14. // Session --------------------------------------------------------------------
  15. // NewSession is called by session stores to create a new session instance.
  16. func NewSession(store Store, name string) *Session {
  17. return &Session{
  18. Values: make(map[interface{}]interface{}),
  19. store: store,
  20. name: name,
  21. Options: new(Options),
  22. }
  23. }
  24. // Session stores the values and optional configuration for a session.
  25. type Session struct {
  26. // The ID of the session, generated by stores. It should not be used for
  27. // user data.
  28. ID string
  29. // Values contains the user-data for the session.
  30. Values map[interface{}]interface{}
  31. Options *Options
  32. IsNew bool
  33. store Store
  34. name string
  35. }
  36. // Flashes returns a slice of flash messages from the session.
  37. //
  38. // A single variadic argument is accepted, and it is optional: it defines
  39. // the flash key. If not defined "_flash" is used by default.
  40. func (s *Session) Flashes(vars ...string) []interface{} {
  41. var flashes []interface{}
  42. key := flashesKey
  43. if len(vars) > 0 {
  44. key = vars[0]
  45. }
  46. if v, ok := s.Values[key]; ok {
  47. // Drop the flashes and return it.
  48. delete(s.Values, key)
  49. flashes = v.([]interface{})
  50. }
  51. return flashes
  52. }
  53. // AddFlash adds a flash message to the session.
  54. //
  55. // A single variadic argument is accepted, and it is optional: it defines
  56. // the flash key. If not defined "_flash" is used by default.
  57. func (s *Session) AddFlash(value interface{}, vars ...string) {
  58. key := flashesKey
  59. if len(vars) > 0 {
  60. key = vars[0]
  61. }
  62. var flashes []interface{}
  63. if v, ok := s.Values[key]; ok {
  64. flashes = v.([]interface{})
  65. }
  66. s.Values[key] = append(flashes, value)
  67. }
  68. // Save is a convenience method to save this session. It is the same as calling
  69. // store.Save(request, response, session). You should call Save before writing to
  70. // the response or returning from the handler.
  71. func (s *Session) Save(r *http.Request, w http.ResponseWriter) error {
  72. return s.store.Save(r, w, s)
  73. }
  74. // Name returns the name used to register the session.
  75. func (s *Session) Name() string {
  76. return s.name
  77. }
  78. // Store returns the session store used to register the session.
  79. func (s *Session) Store() Store {
  80. return s.store
  81. }
  82. // Registry -------------------------------------------------------------------
  83. // sessionInfo stores a session tracked by the registry.
  84. type sessionInfo struct {
  85. s *Session
  86. e error
  87. }
  88. // contextKey is the type used to store the registry in the context.
  89. type contextKey int
  90. // registryKey is the key used to store the registry in the context.
  91. const registryKey contextKey = 0
  92. // GetRegistry returns a registry instance for the current request.
  93. func GetRegistry(r *http.Request) *Registry {
  94. registry := context.Get(r, registryKey)
  95. if registry != nil {
  96. return registry.(*Registry)
  97. }
  98. newRegistry := &Registry{
  99. request: r,
  100. sessions: make(map[string]sessionInfo),
  101. }
  102. context.Set(r, registryKey, newRegistry)
  103. return newRegistry
  104. }
  105. // Registry stores sessions used during a request.
  106. type Registry struct {
  107. request *http.Request
  108. sessions map[string]sessionInfo
  109. }
  110. // Get registers and returns a session for the given name and session store.
  111. //
  112. // It returns a new session if there are no sessions registered for the name.
  113. func (s *Registry) Get(store Store, name string) (session *Session, err error) {
  114. if !isCookieNameValid(name) {
  115. return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name)
  116. }
  117. if info, ok := s.sessions[name]; ok {
  118. session, err = info.s, info.e
  119. } else {
  120. session, err = store.New(s.request, name)
  121. session.name = name
  122. s.sessions[name] = sessionInfo{s: session, e: err}
  123. }
  124. session.store = store
  125. return
  126. }
  127. // Save saves all sessions registered for the current request.
  128. func (s *Registry) Save(w http.ResponseWriter) error {
  129. var errMulti MultiError
  130. for name, info := range s.sessions {
  131. session := info.s
  132. if session.store == nil {
  133. errMulti = append(errMulti, fmt.Errorf(
  134. "sessions: missing store for session %q", name))
  135. } else if err := session.store.Save(s.request, w, session); err != nil {
  136. errMulti = append(errMulti, fmt.Errorf(
  137. "sessions: error saving session %q -- %v", name, err))
  138. }
  139. }
  140. if errMulti != nil {
  141. return errMulti
  142. }
  143. return nil
  144. }
  145. // Helpers --------------------------------------------------------------------
  146. func init() {
  147. gob.Register([]interface{}{})
  148. }
  149. // Save saves all sessions used during the current request.
  150. func Save(r *http.Request, w http.ResponseWriter) error {
  151. return GetRegistry(r).Save(w)
  152. }
  153. // NewCookie returns an http.Cookie with the options set. It also sets
  154. // the Expires field calculated based on the MaxAge value, for Internet
  155. // Explorer compatibility.
  156. func NewCookie(name, value string, options *Options) *http.Cookie {
  157. cookie := newCookieFromOptions(name, value, options)
  158. if options.MaxAge > 0 {
  159. d := time.Duration(options.MaxAge) * time.Second
  160. cookie.Expires = time.Now().Add(d)
  161. } else if options.MaxAge < 0 {
  162. // Set it to the past to expire now.
  163. cookie.Expires = time.Unix(1, 0)
  164. }
  165. return cookie
  166. }
  167. // Error ----------------------------------------------------------------------
  168. // MultiError stores multiple errors.
  169. //
  170. // Borrowed from the App Engine SDK.
  171. type MultiError []error
  172. func (m MultiError) Error() string {
  173. s, n := "", 0
  174. for _, e := range m {
  175. if e != nil {
  176. if n == 0 {
  177. s = e.Error()
  178. }
  179. n++
  180. }
  181. }
  182. switch n {
  183. case 0:
  184. return "(0 errors)"
  185. case 1:
  186. return s
  187. case 2:
  188. return s + " (and 1 other error)"
  189. }
  190. return fmt.Sprintf("%s (and %d other errors)", s, n-1)
  191. }