oauth2.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. // Copyright 2014 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package oauth2 contains Martini handlers to provide
  15. // user login via an OAuth 2.0 backend.
  16. package oauth2
  17. import (
  18. "encoding/json"
  19. "fmt"
  20. "net/http"
  21. "net/url"
  22. "strings"
  23. "time"
  24. "code.google.com/p/goauth2/oauth"
  25. "github.com/go-martini/martini"
  26. "github.com/gogits/session"
  27. "github.com/gogits/gogs/modules/log"
  28. "github.com/gogits/gogs/modules/middleware"
  29. )
  // keyToken is the session key under which the JSON-encoded token is stored.
const keyToken = "oauth2_token"

// keyNextPage is the query parameter that names the page to return to after
// a login or logout.
const keyNextPage = "next"
  // PathLogin is the request path that starts an OAuth 2.0 login.
var PathLogin = "/login"

// PathLogout is the request path that removes the stored token.
var PathLogout = "/logout"

// PathCallback is the request path the OAuth 2.0 backend redirects back to;
// the authorization code received there is exchanged for credentials.
var PathCallback = "/oauth2callback"

// PathError is the request path intended for reporting error cases.
var PathError = "/oauth2error"
  // Options describes an OAuth 2.0 backend and the client registered with it.
type Options struct {
	// ClientId and ClientSecret are the client credentials.
	ClientId, ClientSecret string
	// RedirectURL is the redirect URI handed to the OAuth 2.0 config.
	RedirectURL string
	// Scopes are joined with single spaces to form the requested scope.
	Scopes []string
	// AuthUrl and TokenUrl are the backend's authorization and token
	// endpoints; provider-specific constructors such as Google supply them.
	AuthUrl, TokenUrl string
}
  // Tokens gives handlers access to the logged-in user's OAuth 2.0
// access and refresh tokens.
type Tokens interface {
	// Access returns the access token.
	Access() string
	// Refresh returns the refresh token ("" when none was issued).
	Refresh() string
	// IsExpired reports whether the access token is expired.
	IsExpired() bool
	// ExpiryTime returns when the access token expires.
	ExpiryTime() time.Time
	// ExtraData returns any extra values carried alongside the token.
	ExtraData() map[string]string
}
  63. type token struct {
  64. oauth.Token
  65. }
  66. func (t *token) ExtraData() map[string]string {
  67. return t.Extra
  68. }
  69. // Returns the access token.
  70. func (t *token) Access() string {
  71. return t.AccessToken
  72. }
  73. // Returns the refresh token.
  74. func (t *token) Refresh() string {
  75. return t.RefreshToken
  76. }
  77. // Returns whether the access token is
  78. // expired or not.
  79. func (t *token) IsExpired() bool {
  80. if t == nil {
  81. return true
  82. }
  83. return t.Expired()
  84. }
  85. // Returns the expiry time of the user's
  86. // access token.
  87. func (t *token) ExpiryTime() time.Time {
  88. return t.Expiry
  89. }
  90. // Formats tokens into string.
  91. func (t *token) String() string {
  92. return fmt.Sprintf("tokens: %v", t)
  93. }
  94. // Returns a new Google OAuth 2.0 backend endpoint.
  95. func Google(opts *Options) martini.Handler {
  96. opts.AuthUrl = "https://accounts.google.com/o/oauth2/auth"
  97. opts.TokenUrl = "https://accounts.google.com/o/oauth2/token"
  98. return NewOAuth2Provider(opts)
  99. }
  100. // Returns a new Github OAuth 2.0 backend endpoint.
  101. func Github(opts *Options) martini.Handler {
  102. opts.AuthUrl = "https://github.com/login/oauth/authorize"
  103. opts.TokenUrl = "https://github.com/login/oauth/access_token"
  104. return NewOAuth2Provider(opts)
  105. }
  106. func Facebook(opts *Options) martini.Handler {
  107. opts.AuthUrl = "https://www.facebook.com/dialog/oauth"
  108. opts.TokenUrl = "https://graph.facebook.com/oauth/access_token"
  109. return NewOAuth2Provider(opts)
  110. }
  111. // Returns a generic OAuth 2.0 backend endpoint.
  112. func NewOAuth2Provider(opts *Options) martini.Handler {
  113. config := &oauth.Config{
  114. ClientId: opts.ClientId,
  115. ClientSecret: opts.ClientSecret,
  116. RedirectURL: opts.RedirectURL,
  117. Scope: strings.Join(opts.Scopes, " "),
  118. AuthURL: opts.AuthUrl,
  119. TokenURL: opts.TokenUrl,
  120. }
  121. transport := &oauth.Transport{
  122. Config: config,
  123. Transport: http.DefaultTransport,
  124. }
  125. return func(c martini.Context, ctx *middleware.Context) {
  126. if ctx.Req.Method == "GET" {
  127. switch ctx.Req.URL.Path {
  128. case PathLogin:
  129. login(transport, ctx)
  130. case PathLogout:
  131. logout(transport, ctx)
  132. case PathCallback:
  133. handleOAuth2Callback(transport, ctx)
  134. }
  135. }
  136. tk := unmarshallToken(ctx.Session)
  137. if tk != nil {
  138. // check if the access token is expired
  139. if tk.IsExpired() && tk.Refresh() == "" {
  140. ctx.Session.Delete(keyToken)
  141. tk = nil
  142. }
  143. }
  144. // Inject tokens.
  145. c.MapTo(tk, (*Tokens)(nil))
  146. }
  147. }
  148. // Handler that redirects user to the login page
  149. // if user is not logged in.
  150. // Sample usage:
  151. // m.Get("/login-required", oauth2.LoginRequired, func() ... {})
  152. var LoginRequired martini.Handler = func() martini.Handler {
  153. return func(c martini.Context, ctx *middleware.Context) {
  154. token := unmarshallToken(ctx.Session)
  155. if token == nil || token.IsExpired() {
  156. next := url.QueryEscape(ctx.Req.URL.RequestURI())
  157. ctx.Redirect(PathLogin + "?next=" + next)
  158. return
  159. }
  160. }
  161. }()
  162. func login(t *oauth.Transport, ctx *middleware.Context) {
  163. next := extractPath(ctx.Query(keyNextPage))
  164. if ctx.Session.Get(keyToken) == nil {
  165. // User is not logged in.
  166. ctx.Redirect(t.Config.AuthCodeURL(next))
  167. return
  168. }
  169. // No need to login, redirect to the next page.
  170. ctx.Redirect(next)
  171. }
  172. func logout(t *oauth.Transport, ctx *middleware.Context) {
  173. next := extractPath(ctx.Query(keyNextPage))
  174. ctx.Session.Delete(keyToken)
  175. ctx.Redirect(next)
  176. }
  177. func handleOAuth2Callback(t *oauth.Transport, ctx *middleware.Context) {
  178. if errMsg := ctx.Query("error_description"); len(errMsg) > 0 {
  179. log.Error("oauth2.handleOAuth2Callback: %s", errMsg)
  180. return
  181. }
  182. next := extractPath(ctx.Query("state"))
  183. code := ctx.Query("code")
  184. tk, err := t.Exchange(code)
  185. if err != nil {
  186. // Pass the error message, or allow dev to provide its own
  187. // error handler.
  188. log.Error("oauth2.handleOAuth2Callback(token.Exchange): %v", err)
  189. // ctx.Redirect(PathError)
  190. return
  191. }
  192. // Store the credentials in the session.
  193. val, _ := json.Marshal(tk)
  194. ctx.Session.Set(keyToken, val)
  195. ctx.Redirect(next)
  196. }
  197. func unmarshallToken(s session.SessionStore) (t *token) {
  198. if s.Get(keyToken) == nil {
  199. return
  200. }
  201. data := s.Get(keyToken).([]byte)
  202. var tk oauth.Token
  203. json.Unmarshal(data, &tk)
  204. return &token{tk}
  205. }
  // extractPath reduces next to a local absolute path that is safe to use as
// a post-login/logout redirect target. It falls back to "/".
//
// Only the path component is kept; scheme, host, query and fragment are
// dropped. Some paths are also rejected, each for a stated reason:
//   - "//" prefixes: url.URL.Path is percent-decoded, so "/%2Fevil.example"
//     becomes "//evil.example", which browsers follow as a
//     protocol-relative URL to another host.
//   - "/\" prefixes: some browsers normalize them to "//".
//   - Empty and relative paths: they are not local absolute paths.
//
// Unparseable input also yields "/".
func extractPath(next string) string {
	u, err := url.Parse(next)
	if err != nil {
		return "/"
	}
	p := u.Path
	if !strings.HasPrefix(p, "/") || strings.HasPrefix(p, "//") || strings.HasPrefix(p, `/\`) {
		return "/"
	}
	return p
}