transport.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package oauth2
  5. import (
  6. "errors"
  7. "io"
  8. "net/http"
  9. "sync"
  10. )
  11. // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
  12. // wrapping a base RoundTripper and adding an Authorization header
  13. // with a token from the supplied Sources.
  14. //
  15. // Transport is a low-level mechanism. Most code will use the
  16. // higher-level Config.Client method instead.
  17. type Transport struct {
  18. // Source supplies the token to add to outgoing requests'
  19. // Authorization headers.
  20. Source TokenSource
  21. // Base is the base RoundTripper used to make HTTP requests.
  22. // If nil, http.DefaultTransport is used.
  23. Base http.RoundTripper
  24. mu sync.Mutex // guards modReq
  25. modReq map[*http.Request]*http.Request // original -> modified
  26. }
  27. // RoundTrip authorizes and authenticates the request with an
  28. // access token. If no token exists or token is expired,
  29. // tries to refresh/fetch a new token.
  30. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  31. if t.Source == nil {
  32. return nil, errors.New("oauth2: Transport's Source is nil")
  33. }
  34. token, err := t.Source.Token()
  35. if err != nil {
  36. return nil, err
  37. }
  38. req2 := cloneRequest(req) // per RoundTripper contract
  39. token.SetAuthHeader(req2)
  40. t.setModReq(req, req2)
  41. res, err := t.base().RoundTrip(req2)
  42. if err != nil {
  43. t.setModReq(req, nil)
  44. return nil, err
  45. }
  46. res.Body = &onEOFReader{
  47. rc: res.Body,
  48. fn: func() { t.setModReq(req, nil) },
  49. }
  50. return res, nil
  51. }
  52. // CancelRequest cancels an in-flight request by closing its connection.
  53. func (t *Transport) CancelRequest(req *http.Request) {
  54. type canceler interface {
  55. CancelRequest(*http.Request)
  56. }
  57. if cr, ok := t.base().(canceler); ok {
  58. t.mu.Lock()
  59. modReq := t.modReq[req]
  60. delete(t.modReq, req)
  61. t.mu.Unlock()
  62. cr.CancelRequest(modReq)
  63. }
  64. }
  65. func (t *Transport) base() http.RoundTripper {
  66. if t.Base != nil {
  67. return t.Base
  68. }
  69. return http.DefaultTransport
  70. }
  71. func (t *Transport) setModReq(orig, mod *http.Request) {
  72. t.mu.Lock()
  73. defer t.mu.Unlock()
  74. if t.modReq == nil {
  75. t.modReq = make(map[*http.Request]*http.Request)
  76. }
  77. if mod == nil {
  78. delete(t.modReq, orig)
  79. } else {
  80. t.modReq[orig] = mod
  81. }
  82. }
  // cloneRequest returns a copy of r whose headers can be modified without
  // affecting r, as the RoundTripper contract requires.
  //
  // Only the Header map and its value slices are duplicated. Every other
  // field is copied by value, so pointer-like fields (URL, Body, ...) are
  // still shared with r. The clone always has a non-nil Header, even when
  // r.Header is nil.
  func cloneRequest(r *http.Request) *http.Request {
  	clone := *r
  	clone.Header = make(http.Header, len(r.Header))
  	for name, values := range r.Header {
  		clone.Header[name] = append([]string(nil), values...)
  	}
  	return &clone
  }
  // onEOFReader wraps a ReadCloser and invokes fn at most once: the first
  // time the wrapped reader returns io.EOF from Read, or when Close is
  // called, whichever happens first. Transport wraps response bodies with
  // it to drop its per-request bookkeeping once the caller is done.
  type onEOFReader struct {
  	rc io.ReadCloser // wrapped body
  	fn func()        // callback; cleared after it has run
  }
  
  var _ io.ReadCloser = (*onEOFReader)(nil)
  
  // Read forwards to the wrapped reader and returns its results unchanged,
  // running the callback when the wrapped reader reports io.EOF.
  func (r *onEOFReader) Read(p []byte) (n int, err error) {
  	if n, err = r.rc.Read(p); err == io.EOF {
  		r.runFunc()
  	}
  	return n, err
  }
  
  // Close closes the wrapped reader, then runs the callback whether or not
  // closing succeeded, and returns the error from the wrapped Close.
  func (r *onEOFReader) Close() error {
  	closeErr := r.rc.Close()
  	r.runFunc()
  	return closeErr
  }
  
  // runFunc calls fn if it is still set and then clears it, so later calls
  // are no-ops.
  func (r *onEOFReader) runFunc() {
  	fn := r.fn
  	if fn == nil {
  		return
  	}
  	fn()
  	r.fn = nil
  }
  117. }