cors.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. package blademaster
  2. import (
  3. "net/http"
  4. "strconv"
  5. "strings"
  6. "time"
  7. "go-common/library/log"
  8. "github.com/pkg/errors"
  9. )
  // allowOriginHosts holds the lower-case host suffixes accepted by the default
// CORS() middleware. Every entry begins with a dot, so when it is matched as a
// suffix of an Origin it only accepts subdomains of the listed domain.
var allowOriginHosts = []string{
	".bilibili.com",
	".biligame.com",
	".bilibili.co",
	".im9.com",
	".acg.tv",
	".hdslb.com",
}
  // CORSConfig represents all available options for the CORS middleware.
// newCORS calls Validate on it and panics if the settings are contradictory.
type CORSConfig struct {
	// AllowAllOrigins accepts requests from every origin and answers with
	// "Access-Control-Allow-Origin: *". It must not be combined with
	// AllowOrigins or AllowOriginFunc (Validate rejects that).
	AllowAllOrigins bool
	// AllowOrigins is a list of origins a cross-domain request can be executed
	// from. Each entry must be "*" or start with "http://" or "https://".
	// Entries are trimmed and lower-cased before use. The special value "*" is
	// meant to allow all origins. NOTE(review): confirm the origin check treats
	// "*" as a wildcard rather than as a literal origin value.
	// Default value is [].
	AllowOrigins []string
	// AllowOriginFunc is a custom function to validate the origin. It takes the
	// origin as argument and returns true if it is allowed. It is consulted for
	// origins that do not match an AllowOrigins entry, so both options can be
	// used together.
	AllowOriginFunc func(origin string) bool
	// AllowMethods is a list of methods the client is allowed to use with
	// cross-domain requests. Values are upper-cased and sent, comma-joined, as
	// Access-Control-Allow-Methods on both preflight and normal responses.
	// No default is applied by newCORS: when empty, the header is omitted.
	// The default CORS() middleware uses GET and POST.
	AllowMethods []string
	// AllowHeaders is a list of non-simple request headers the client is
	// allowed to use with cross-domain requests. Values are canonicalized and
	// sent, comma-joined, as Access-Control-Allow-Headers on preflight responses.
	AllowHeaders []string
	// AllowCredentials indicates whether the request can include user
	// credentials like cookies, HTTP authentication or client side SSL
	// certificates. When true, "Access-Control-Allow-Credentials: true" is sent
	// on both preflight and normal responses.
	AllowCredentials bool
	// ExposeHeaders lists response headers that are safe to expose to the API
	// of a CORS API specification. Values are canonicalized and sent,
	// comma-joined, as Access-Control-Expose-Headers on non-preflight responses.
	ExposeHeaders []string
	// MaxAge indicates how long the results of a preflight request can be
	// cached. It is sent as Access-Control-Max-Age in whole seconds (truncated),
	// and only when it is positive.
	MaxAge time.Duration
}
  type (
	// cors is the compiled state of one middleware instance, built by newCORS:
	// the origin rules plus the response headers precomputed for preflight
	// (OPTIONS) and normal requests.
	cors struct {
		allowAllOrigins, allowCredentials bool
		allowOriginFunc                   func(string) bool
		allowOrigins                      []string
		normalHeaders, preflightHeaders   http.Header
	}

	// converter maps one string to another; convert applies it to every
	// element of a slice (e.g. strings.ToUpper, http.CanonicalHeaderKey).
	converter func(string) string
)
  56. // Validate is check configuration of user defined.
  57. func (c *CORSConfig) Validate() error {
  58. if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
  59. return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
  60. }
  61. if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
  62. return errors.New("conflict settings: all origins disabled")
  63. }
  64. for _, origin := range c.AllowOrigins {
  65. if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
  66. return errors.New("bad origin: origins must either be '*' or include http:// or https://")
  67. }
  68. }
  69. return nil
  70. }
  71. // CORS returns the location middleware with default configuration.
  72. func CORS() HandlerFunc {
  73. config := &CORSConfig{
  74. AllowMethods: []string{"GET", "POST"},
  75. AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
  76. AllowCredentials: true,
  77. MaxAge: time.Duration(0),
  78. AllowOriginFunc: func(origin string) bool {
  79. for _, host := range allowOriginHosts {
  80. if strings.HasSuffix(strings.ToLower(origin), host) {
  81. return true
  82. }
  83. }
  84. return false
  85. },
  86. }
  87. return newCORS(config)
  88. }
  89. // newCORS returns the location middleware with user-defined custom configuration.
  90. func newCORS(config *CORSConfig) HandlerFunc {
  91. if err := config.Validate(); err != nil {
  92. panic(err.Error())
  93. }
  94. cors := &cors{
  95. allowOriginFunc: config.AllowOriginFunc,
  96. allowAllOrigins: config.AllowAllOrigins,
  97. allowCredentials: config.AllowCredentials,
  98. allowOrigins: normalize(config.AllowOrigins),
  99. normalHeaders: generateNormalHeaders(config),
  100. preflightHeaders: generatePreflightHeaders(config),
  101. }
  102. return func(c *Context) {
  103. cors.applyCORS(c)
  104. }
  105. }
  106. func (cors *cors) applyCORS(c *Context) {
  107. origin := c.Request.Header.Get("Origin")
  108. if len(origin) == 0 {
  109. // request is not a CORS request
  110. return
  111. }
  112. if !cors.validateOrigin(origin) {
  113. log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin)
  114. c.AbortWithStatus(http.StatusForbidden)
  115. return
  116. }
  117. if c.Request.Method == "OPTIONS" {
  118. cors.handlePreflight(c)
  119. defer c.AbortWithStatus(200)
  120. } else {
  121. cors.handleNormal(c)
  122. }
  123. if !cors.allowAllOrigins {
  124. header := c.Writer.Header()
  125. header.Set("Access-Control-Allow-Origin", origin)
  126. }
  127. }
  128. func (cors *cors) validateOrigin(origin string) bool {
  129. if cors.allowAllOrigins {
  130. return true
  131. }
  132. for _, value := range cors.allowOrigins {
  133. if value == origin {
  134. return true
  135. }
  136. }
  137. if cors.allowOriginFunc != nil {
  138. return cors.allowOriginFunc(origin)
  139. }
  140. return false
  141. }
  142. func (cors *cors) handlePreflight(c *Context) {
  143. header := c.Writer.Header()
  144. for key, value := range cors.preflightHeaders {
  145. header[key] = value
  146. }
  147. }
  148. func (cors *cors) handleNormal(c *Context) {
  149. header := c.Writer.Header()
  150. for key, value := range cors.normalHeaders {
  151. header[key] = value
  152. }
  153. }
  154. func generateNormalHeaders(c *CORSConfig) http.Header {
  155. headers := make(http.Header)
  156. if c.AllowCredentials {
  157. headers.Set("Access-Control-Allow-Credentials", "true")
  158. }
  159. // backport support for early browsers
  160. if len(c.AllowMethods) > 0 {
  161. allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
  162. value := strings.Join(allowMethods, ",")
  163. headers.Set("Access-Control-Allow-Methods", value)
  164. }
  165. if len(c.ExposeHeaders) > 0 {
  166. exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
  167. headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
  168. }
  169. if c.AllowAllOrigins {
  170. headers.Set("Access-Control-Allow-Origin", "*")
  171. } else {
  172. headers.Set("Vary", "Origin")
  173. }
  174. return headers
  175. }
  176. func generatePreflightHeaders(c *CORSConfig) http.Header {
  177. headers := make(http.Header)
  178. if c.AllowCredentials {
  179. headers.Set("Access-Control-Allow-Credentials", "true")
  180. }
  181. if len(c.AllowMethods) > 0 {
  182. allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
  183. value := strings.Join(allowMethods, ",")
  184. headers.Set("Access-Control-Allow-Methods", value)
  185. }
  186. if len(c.AllowHeaders) > 0 {
  187. allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
  188. value := strings.Join(allowHeaders, ",")
  189. headers.Set("Access-Control-Allow-Headers", value)
  190. }
  191. if c.MaxAge > time.Duration(0) {
  192. value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
  193. headers.Set("Access-Control-Max-Age", value)
  194. }
  195. if c.AllowAllOrigins {
  196. headers.Set("Access-Control-Allow-Origin", "*")
  197. } else {
  198. // Always set Vary headers
  199. // see https://github.com/rs/cors/issues/10,
  200. // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
  201. headers.Add("Vary", "Origin")
  202. headers.Add("Vary", "Access-Control-Request-Method")
  203. headers.Add("Vary", "Access-Control-Request-Headers")
  204. }
  205. return headers
  206. }
  // normalize trims and lower-cases every value and drops duplicates, keeping
// the first occurrence and the original order. A nil input yields nil; a
// non-nil empty input yields a non-nil empty slice.
func normalize(values []string) []string {
	if values == nil {
		return nil
	}
	seen := make(map[string]struct{}, len(values))
	out := make([]string, 0, len(values))
	for _, raw := range values {
		v := strings.ToLower(strings.TrimSpace(raw))
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
  223. func convert(s []string, c converter) []string {
  224. var out []string
  225. for _, i := range s {
  226. out = append(out, c(i))
  227. }
  228. return out
  229. }