scrrun/dictionary: Compare key values in addition to their hashes.

Signed-off-by: Nikolay Sivov <nsivov@codeweavers.com>
This commit is contained in:
Nikolay Sivov 2023-06-14 06:45:23 +02:00 committed by Alexandre Julliard
parent 7cd77b2a78
commit e6229b5273
2 changed files with 140 additions and 21 deletions

View file

@ -101,7 +101,28 @@ static inline struct list *get_bucket_head(struct dictionary *dict, DWORD hash)
static inline BOOL is_string_key(const VARIANT *key)
{
return V_VT(key) == VT_BSTR || V_VT(key) == (VT_BSTR|VT_BYREF);
return V_VT(key) == VT_BSTR;
}
static inline BOOL is_ptr_key(const VARIANT *key)
{
return V_VT(key) == VT_UNKNOWN || V_VT(key) == VT_DISPATCH;
}
static inline BOOL is_numeric_key(const VARIANT *key)
{
switch (V_VT(key))
{
case VT_UI1:
case VT_I2:
case VT_I4:
case VT_DATE:
case VT_R4:
case VT_R8:
return TRUE;
default:
return FALSE;
}
}
/* Only for VT_BSTR or VT_BSTR|VT_BYREF types */
@ -126,28 +147,59 @@ static inline int strcmp_key(const struct dictionary *dict, const VARIANT *key1,
return dict->method == BinaryCompare ? wcscmp(str1, str2) : wcsicmp(str1, str2);
}
static BOOL is_matching_key(const struct dictionary *dict, const struct keyitem_pair *pair, const VARIANT *key, DWORD hash)
static inline BOOL numeric_key_eq(const VARIANT *key1, const VARIANT *key2)
{
if (is_string_key(key) && is_string_key(&pair->key)) {
if (hash != pair->hash)
return FALSE;
VARIANT v1, v2;
return strcmp_key(dict, key, &pair->key) == 0;
}
if ((is_string_key(key) && !is_string_key(&pair->key)) ||
(!is_string_key(key) && is_string_key(&pair->key)))
VariantInit(&v1);
if (FAILED(VariantChangeType(&v1, key1, 0, VT_R4)))
return FALSE;
/* for numeric keys only check hash */
return hash == pair->hash;
VariantInit(&v2);
if (FAILED(VariantChangeType(&v2, key2, 0, VT_R4)))
return FALSE;
return V_R4(&v1) == V_R4(&v2);
}
static BOOL is_matching_key(const struct dictionary *dict, const struct keyitem_pair *pair, const VARIANT *key, DWORD hash)
{
if (is_string_key(key) != is_string_key(&pair->key))
{
return FALSE;
}
else if (is_string_key(key) && is_string_key(&pair->key))
{
return hash == pair->hash && !strcmp_key(dict, key, &pair->key);
}
else if (is_ptr_key(key) != is_ptr_key(&pair->key))
{
return FALSE;
}
else if (is_ptr_key(key) && is_ptr_key(&pair->key))
{
return hash == pair->hash && V_UNKNOWN(key) == V_UNKNOWN(&pair->key);
}
else if (is_numeric_key(key) != is_numeric_key(&pair->key))
{
return FALSE;
}
else if (is_numeric_key(key) && is_numeric_key(&pair->key))
{
return hash == pair->hash && numeric_key_eq(key, &pair->key);
}
else
{
WARN("Unexpected key type %#x.\n", V_VT(key));
return FALSE;
}
}
static struct keyitem_pair *get_keyitem_pair(struct dictionary *dict, VARIANT *key)
{
struct keyitem_pair *pair;
struct list *head, *entry;
VARIANT hash;
VARIANT hash, v;
HRESULT hr;
hr = IDictionary_get_HashVal(&dict->IDictionary_iface, key, &hash);
@ -158,12 +210,23 @@ static struct keyitem_pair *get_keyitem_pair(struct dictionary *dict, VARIANT *k
if (!head->next || list_empty(head))
return NULL;
VariantInit(&v);
if (FAILED(VariantCopyInd(&v, key)))
return NULL;
entry = list_head(head);
do {
do
{
pair = LIST_ENTRY(entry, struct keyitem_pair, bucket);
if (is_matching_key(dict, pair, key, V_I4(&hash))) return pair;
if (is_matching_key(dict, pair, &v, V_I4(&hash)))
{
VariantClear(&v);
return pair;
}
} while ((entry = list_next(head, entry)));
VariantClear(&v);
return NULL;
}

View file

@ -341,7 +341,7 @@ static void test_hash_value(void)
};
static const FLOAT float_hash_tests[] = {
0.0, -1.0, 100.0, 1.0, 255.0, 1.234
0.0, -1.0, 100.0, 1.0, 255.0, 1.234, 2175.0, 6259.0
};
IDictionary *dict;
@ -848,6 +848,29 @@ static void test_Keys(void)
VariantClear(&keys);
hr = IDictionary_RemoveAll(dict);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
/* Integer key type. */
V_VT(&key) = VT_I4;
V_I4(&key) = 0;
VariantInit(&item);
hr = IDictionary_Add(dict, &key, &item);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
VariantInit(&keys);
hr = IDictionary_Keys(dict, &keys);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
ok(V_VT(&keys) == (VT_ARRAY|VT_VARIANT), "got %d\n", V_VT(&keys));
VariantInit(&key);
index = 0;
hr = SafeArrayGetElement(V_ARRAY(&keys), &index, &key);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
ok(V_VT(&key) == VT_I4, "got %d\n", V_VT(&key));
VariantClear(&keys);
IDictionary_Release(dict);
}
@ -910,7 +933,7 @@ static void test_Item(void)
static void test_Add(void)
{
VARIANT key, item;
VARIANT item, key1, key2, hash1, hash2;
IDictionary *dict;
HRESULT hr;
BSTR str;
@ -920,19 +943,52 @@ static void test_Add(void)
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
str = SysAllocString(L"testW");
V_VT(&key) = VT_I2;
V_I2(&key) = 1;
V_VT(&key1) = VT_I2;
V_I2(&key1) = 1;
V_VT(&item) = VT_BSTR|VT_BYREF;
V_BSTRREF(&item) = &str;
hr = IDictionary_Add(dict, &key, &item);
hr = IDictionary_Add(dict, &key1, &item);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
hr = IDictionary_get_Item(dict, &key, &item);
hr = IDictionary_get_Item(dict, &key1, &item);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
ok(V_VT(&item) == VT_BSTR, "got %d\n", V_VT(&item));
SysFreeString(str);
/* Items with matching key hashes, float keys. */
V_VT(&key1) = VT_R4;
V_R4(&key1) = 2175.0;
V_VT(&key2) = VT_R4;
V_R4(&key2) = 6259.0;
VariantInit(&hash1);
hr = IDictionary_get_HashVal(dict, &key1, &hash1);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
VariantInit(&hash2);
hr = IDictionary_get_HashVal(dict, &key1, &hash2);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
ok(V_VT(&hash1) == VT_I4, "Unexpected type %d.\n", V_VT(&hash1));
ok(V_VT(&hash2) == VT_I4, "Unexpected type %d.\n", V_VT(&hash2));
ok(V_I4(&hash1) == V_I4(&hash2), "Unexpected hash %#lx.\n", V_I4(&hash1));
V_VT(&item) = VT_BSTR;
V_BSTR(&item) = SysAllocString(L"float1");
hr = IDictionary_Add(dict, &key1, &item);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
hr = IDictionary_Add(dict, &key2, &item);
ok(hr == S_OK, "Unexpected hr %#lx.\n", hr);
V_VT(&key1) = VT_I4;
V_I4(&key1) = 2175;
hr = IDictionary_Add(dict, &key1, &item);
ok(hr == CTL_E_KEY_ALREADY_EXISTS, "Unexpected hr %#lx.\n", hr);
VariantClear(&item);
IDictionary_Release(dict);
}