@ -3,10 +3,24 @@ package ntlmssp
import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"strings"
)
// GetDomain : parse domain name from based on slashes in the input
func GetDomain ( user string ) ( string , string ) {
domain := ""
if strings . Contains ( user , "\\" ) {
ucomponents := strings . SplitN ( user , "\\" , 2 )
domain = ucomponents [ 0 ]
user = ucomponents [ 1 ]
}
return user , domain
}
//Negotiator is a http.Roundtripper decorator that automatically
//converts basic authentication to NTLM/Negotiate authentication when appropriate.
type Negotiator struct { http . RoundTripper }
@ -47,9 +61,10 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
}
resauth := authheader ( res . Header . Get ( "Www-Authenticate" ) )
if ! resauth . IsNegotiate ( ) {
if ! resauth . IsNegotiate ( ) && ! resauth . IsNTLM ( ) {
// Unauthorized, Negotiate not requested, let's try with basic auth
req . Header . Set ( "Authorization" , string ( reqauth ) )
io . Copy ( ioutil . Discard , res . Body )
res . Body . Close ( )
req . Body = ioutil . NopCloser ( bytes . NewReader ( body . Bytes ( ) ) )
@ -63,8 +78,9 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
resauth = authheader ( res . Header . Get ( "Www-Authenticate" ) )
}
if resauth . IsNegotiate ( ) {
if resauth . IsNegotiate ( ) || resauth . IsNTLM ( ) {
// 401 with request:Basic and response:Negotiate
io . Copy ( ioutil . Discard , res . Body )
res . Body . Close ( )
// recycle credentials
@ -73,9 +89,21 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
return nil , err
}
// get domain from username
domain := ""
u , domain = GetDomain ( u )
// send negotiate
negotiateMessage := NewNegotiateMessage ( )
req . Header . Set ( "Authorization" , "Negotiate " + base64 . StdEncoding . EncodeToString ( negotiateMessage ) )
negotiateMessage , err := NewNegotiateMessage ( domain , "" )
if err != nil {
return nil , err
}
if resauth . IsNTLM ( ) {
req . Header . Set ( "Authorization" , "NTLM " + base64 . StdEncoding . EncodeToString ( negotiateMessage ) )
} else {
req . Header . Set ( "Authorization" , "Negotiate " + base64 . StdEncoding . EncodeToString ( negotiateMessage ) )
}
req . Body = ioutil . NopCloser ( bytes . NewReader ( body . Bytes ( ) ) )
res , err = rt . RoundTrip ( req )
@ -89,10 +117,11 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
if err != nil {
return nil , err
}
if ! resauth . IsNegotiate ( ) || len ( challengeMessage ) == 0 {
if ! ( resauth . IsNegotiate ( ) || resauth . IsNTLM ( ) ) || len ( challengeMessage ) == 0 {
// Negotiation failed, let client deal with response
return res , nil
}
io . Copy ( ioutil . Discard , res . Body )
res . Body . Close ( )
// send authenticate
@ -100,7 +129,12 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
if err != nil {
return nil , err
}
req . Header . Set ( "Authorization" , "Negotiate " + base64 . StdEncoding . EncodeToString ( authenticateMessage ) )
if resauth . IsNTLM ( ) {
req . Header . Set ( "Authorization" , "NTLM " + base64 . StdEncoding . EncodeToString ( authenticateMessage ) )
} else {
req . Header . Set ( "Authorization" , "Negotiate " + base64 . StdEncoding . EncodeToString ( authenticateMessage ) )
}
req . Body = ioutil . NopCloser ( bytes . NewReader ( body . Bytes ( ) ) )
res , err = rt . RoundTrip ( req )