wininet: Cache non basic authorization data.

This commit is contained in:
Piotr Caban 2010-07-17 14:08:44 +02:00 committed by Alexandre Julliard
parent 98fb747aa9
commit c398e6fc77
2 changed files with 142 additions and 6 deletions

View file

@ -179,7 +179,7 @@ struct gzip_stream_t {
BOOL end_of_data;
};
typedef struct _authorizationData
typedef struct _basicAuthorizationData
{
struct list entry;
@ -187,9 +187,24 @@ typedef struct _authorizationData
LPWSTR lpszwRealm;
LPSTR lpszAuthorization;
UINT AuthorizationLen;
} basicAuthorizationData;
typedef struct _authorizationData
{
struct list entry;
LPWSTR host;
LPWSTR scheme;
LPWSTR domain;
UINT domain_len;
LPWSTR user;
UINT user_len;
LPWSTR password;
UINT password_len;
} authorizationData;
static struct list basicAuthorizationCache = LIST_INIT(basicAuthorizationCache);
static struct list authorizationCache = LIST_INIT(authorizationCache);
static CRITICAL_SECTION authcache_cs;
static CRITICAL_SECTION_DEBUG critsect_debug =
@ -570,13 +585,13 @@ static void destroy_authinfo( struct HttpAuthInfo *authinfo )
static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR *auth_data)
{
authorizationData *ad;
basicAuthorizationData *ad;
UINT rc = 0;
TRACE("Looking for authorization for %s:%s\n",debugstr_w(host),debugstr_w(realm));
EnterCriticalSection(&authcache_cs);
LIST_FOR_EACH_ENTRY(ad, &basicAuthorizationCache, authorizationData, entry)
LIST_FOR_EACH_ENTRY(ad, &basicAuthorizationCache, basicAuthorizationData, entry)
{
if (!strcmpiW(host,ad->lpszwHost) && !strcmpW(realm,ad->lpszwRealm))
{
@ -594,14 +609,14 @@ static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR
static void cache_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR auth_data, UINT auth_data_len)
{
struct list *cursor;
authorizationData* ad = NULL;
basicAuthorizationData* ad = NULL;
TRACE("caching authorization for %s:%s = %s\n",debugstr_w(host),debugstr_w(realm),debugstr_an(auth_data,auth_data_len));
EnterCriticalSection(&authcache_cs);
LIST_FOR_EACH(cursor, &basicAuthorizationCache)
{
authorizationData *check = LIST_ENTRY(cursor,authorizationData,entry);
basicAuthorizationData *check = LIST_ENTRY(cursor,basicAuthorizationData,entry);
if (!strcmpiW(host,check->lpszwHost) && !strcmpW(realm,check->lpszwRealm))
{
ad = check;
@ -619,7 +634,7 @@ static void cache_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR auth_data
}
else
{
ad = HeapAlloc(GetProcessHeap(),0,sizeof(authorizationData));
ad = HeapAlloc(GetProcessHeap(),0,sizeof(basicAuthorizationData));
ad->lpszwHost = heap_strdupW(host);
ad->lpszwRealm = heap_strdupW(realm);
ad->lpszAuthorization = HeapAlloc(GetProcessHeap(),0,auth_data_len);
@ -631,6 +646,95 @@ static void cache_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR auth_data
LeaveCriticalSection(&authcache_cs);
}
static BOOL retrieve_cached_authorization(LPWSTR host, LPWSTR scheme,
SEC_WINNT_AUTH_IDENTITY_W *nt_auth_identity)
{
authorizationData *ad;
TRACE("Looking for authorization for %s:%s\n", debugstr_w(host), debugstr_w(scheme));
EnterCriticalSection(&authcache_cs);
LIST_FOR_EACH_ENTRY(ad, &authorizationCache, authorizationData, entry) {
if(!strcmpiW(host, ad->host) && !strcmpiW(scheme, ad->scheme)) {
TRACE("Authorization found in cache\n");
nt_auth_identity->User = heap_strdupW(ad->user);
nt_auth_identity->Password = heap_strdupW(ad->password);
nt_auth_identity->Domain = HeapAlloc(GetProcessHeap(), 0, sizeof(WCHAR)*ad->domain_len);
if(!nt_auth_identity->User || !nt_auth_identity->Password ||
(!nt_auth_identity->Domain && ad->domain_len)) {
HeapFree(GetProcessHeap(), 0, nt_auth_identity->User);
HeapFree(GetProcessHeap(), 0, nt_auth_identity->Password);
HeapFree(GetProcessHeap(), 0, nt_auth_identity->Domain);
break;
}
nt_auth_identity->Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
nt_auth_identity->UserLength = ad->user_len;
nt_auth_identity->PasswordLength = ad->password_len;
memcpy(nt_auth_identity->Domain, ad->domain, sizeof(WCHAR)*ad->domain_len);
nt_auth_identity->DomainLength = ad->domain_len;
LeaveCriticalSection(&authcache_cs);
return TRUE;
}
}
LeaveCriticalSection(&authcache_cs);
return FALSE;
}
static void cache_authorization(LPWSTR host, LPWSTR scheme,
SEC_WINNT_AUTH_IDENTITY_W *nt_auth_identity)
{
authorizationData *ad;
BOOL found = FALSE;
TRACE("Caching authorization for %s:%s\n", debugstr_w(host), debugstr_w(scheme));
EnterCriticalSection(&authcache_cs);
LIST_FOR_EACH_ENTRY(ad, &authorizationCache, authorizationData, entry)
if(!strcmpiW(host, ad->host) && !strcmpiW(scheme, ad->scheme)) {
found = TRUE;
break;
}
if(found) {
HeapFree(GetProcessHeap(), 0, ad->user);
HeapFree(GetProcessHeap(), 0, ad->password);
HeapFree(GetProcessHeap(), 0, ad->domain);
} else {
ad = HeapAlloc(GetProcessHeap(), 0, sizeof(authorizationData));
if(!ad) {
LeaveCriticalSection(&authcache_cs);
return;
}
ad->host = heap_strdupW(host);
ad->scheme = heap_strdupW(scheme);
list_add_head(&authorizationCache, &ad->entry);
}
ad->user = heap_strndupW(nt_auth_identity->User, nt_auth_identity->UserLength);
ad->password = heap_strndupW(nt_auth_identity->Password, nt_auth_identity->PasswordLength);
ad->domain = heap_strndupW(nt_auth_identity->Domain, nt_auth_identity->DomainLength);
ad->user_len = nt_auth_identity->UserLength;
ad->password_len = nt_auth_identity->PasswordLength;
ad->domain_len = nt_auth_identity->DomainLength;
if(!ad->host || !ad->scheme || !ad->user || !ad->password
|| (nt_auth_identity->Domain && !ad->domain)) {
HeapFree(GetProcessHeap(), 0, ad->host);
HeapFree(GetProcessHeap(), 0, ad->scheme);
HeapFree(GetProcessHeap(), 0, ad->user);
HeapFree(GetProcessHeap(), 0, ad->password);
HeapFree(GetProcessHeap(), 0, ad->domain);
list_remove(&ad->entry);
HeapFree(GetProcessHeap(), 0, ad);
}
LeaveCriticalSection(&authcache_cs);
}
static BOOL HTTP_DoAuthorization( http_request_t *lpwhr, LPCWSTR pszAuthValue,
struct HttpAuthInfo **ppAuthInfo,
LPWSTR domain_and_username, LPWSTR password,
@ -705,7 +809,11 @@ static BOOL HTTP_DoAuthorization( http_request_t *lpwhr, LPCWSTR pszAuthValue,
nt_auth_identity.DomainLength = domain ? user - domain - 1 : 0;
nt_auth_identity.Password = password;
nt_auth_identity.PasswordLength = strlenW(nt_auth_identity.Password);
cache_authorization(host, pAuthInfo->scheme, &nt_auth_identity);
}
else if(retrieve_cached_authorization(host, pAuthInfo->scheme, &nt_auth_identity))
pAuthData = &nt_auth_identity;
else
/* use default credentials */
pAuthData = NULL;
@ -715,6 +823,13 @@ static BOOL HTTP_DoAuthorization( http_request_t *lpwhr, LPCWSTR pszAuthValue,
pAuthData, NULL,
NULL, &pAuthInfo->cred,
&exp);
if(pAuthData && !domain_and_username) {
HeapFree(GetProcessHeap(), 0, nt_auth_identity.User);
HeapFree(GetProcessHeap(), 0, nt_auth_identity.Domain);
HeapFree(GetProcessHeap(), 0, nt_auth_identity.Password);
}
if (sec_status == SEC_E_OK)
{
PSecPkgInfoW sec_pkg_info;

View file

@ -71,6 +71,27 @@ static inline LPWSTR heap_strdupW(LPCWSTR str)
return ret;
}
static inline LPWSTR heap_strndupW(LPCWSTR str, UINT max_len)
{
LPWSTR ret;
UINT len;
if(!str)
return NULL;
for(len=0; len<max_len; len++)
if(str[len] == '\0')
break;
ret = HeapAlloc(GetProcessHeap(), 0, sizeof(WCHAR)*(len+1));
if(ret) {
memcpy(ret, str, sizeof(WCHAR)*len);
ret[len] = '\0';
}
return ret;
}
static inline WCHAR *heap_strdupAtoW(const char *str)
{
LPWSTR ret = NULL;