diff --git a/dlls/scrrun/dictionary.c b/dlls/scrrun/dictionary.c index bb9c3610d1b..d69b513efc2 100644 --- a/dlls/scrrun/dictionary.c +++ b/dlls/scrrun/dictionary.c @@ -101,7 +101,28 @@ static inline struct list *get_bucket_head(struct dictionary *dict, DWORD hash) static inline BOOL is_string_key(const VARIANT *key) { - return V_VT(key) == VT_BSTR || V_VT(key) == (VT_BSTR|VT_BYREF); + return V_VT(key) == VT_BSTR; +} + +static inline BOOL is_ptr_key(const VARIANT *key) +{ + return V_VT(key) == VT_UNKNOWN || V_VT(key) == VT_DISPATCH; +} + +static inline BOOL is_numeric_key(const VARIANT *key) +{ + switch (V_VT(key)) + { + case VT_UI1: + case VT_I2: + case VT_I4: + case VT_DATE: + case VT_R4: + case VT_R8: + return TRUE; + default: + return FALSE; + } } /* Only for VT_BSTR or VT_BSTR|VT_BYREF types */ @@ -126,28 +147,59 @@ static inline int strcmp_key(const struct dictionary *dict, const VARIANT *key1, return dict->method == BinaryCompare ? wcscmp(str1, str2) : wcsicmp(str1, str2); } -static BOOL is_matching_key(const struct dictionary *dict, const struct keyitem_pair *pair, const VARIANT *key, DWORD hash) +static inline BOOL numeric_key_eq(const VARIANT *key1, const VARIANT *key2) { - if (is_string_key(key) && is_string_key(&pair->key)) { - if (hash != pair->hash) - return FALSE; + VARIANT v1, v2; - return strcmp_key(dict, key, &pair->key) == 0; - } - - if ((is_string_key(key) && !is_string_key(&pair->key)) || - (!is_string_key(key) && is_string_key(&pair->key))) + VariantInit(&v1); + if (FAILED(VariantChangeType(&v1, key1, 0, VT_R4))) return FALSE; - /* for numeric keys only check hash */ - return hash == pair->hash; + VariantInit(&v2); + if (FAILED(VariantChangeType(&v2, key2, 0, VT_R4))) + return FALSE; + + return V_R4(&v1) == V_R4(&v2); +} + +static BOOL is_matching_key(const struct dictionary *dict, const struct keyitem_pair *pair, const VARIANT *key, DWORD hash) +{ + if (is_string_key(key) != is_string_key(&pair->key)) + { + return FALSE; + } + else if (is_string_key(key) && is_string_key(&pair->key)) + { + return hash == pair->hash && !strcmp_key(dict, key, &pair->key); + } + else if (is_ptr_key(key) != is_ptr_key(&pair->key)) + { + return FALSE; + } + else if (is_ptr_key(key) && is_ptr_key(&pair->key)) + { + return hash == pair->hash && V_UNKNOWN(key) == V_UNKNOWN(&pair->key); + } + else if (is_numeric_key(key) != is_numeric_key(&pair->key)) + { + return FALSE; + } + else if (is_numeric_key(key) && is_numeric_key(&pair->key)) + { + return hash == pair->hash && numeric_key_eq(key, &pair->key); + } + else + { + WARN("Unexpected key type %#x.\n", V_VT(key)); + return FALSE; + } } static struct keyitem_pair *get_keyitem_pair(struct dictionary *dict, VARIANT *key) { struct keyitem_pair *pair; struct list *head, *entry; - VARIANT hash; + VARIANT hash, v; HRESULT hr; hr = IDictionary_get_HashVal(&dict->IDictionary_iface, key, &hash); @@ -158,12 +210,23 @@ static struct keyitem_pair *get_keyitem_pair(struct dictionary *dict, VARIANT *k if (!head->next || list_empty(head)) return NULL; + VariantInit(&v); + if (FAILED(VariantCopyInd(&v, key))) + return NULL; + entry = list_head(head); - do { + do + { pair = LIST_ENTRY(entry, struct keyitem_pair, bucket); - if (is_matching_key(dict, pair, key, V_I4(&hash))) return pair; + if (is_matching_key(dict, pair, &v, V_I4(&hash))) + { + VariantClear(&v); + return pair; + } } while ((entry = list_next(head, entry))); + VariantClear(&v); + return NULL; } diff --git a/dlls/scrrun/tests/dictionary.c b/dlls/scrrun/tests/dictionary.c index 16a195955ac..3232ef7321a 100644 --- a/dlls/scrrun/tests/dictionary.c +++ b/dlls/scrrun/tests/dictionary.c @@ -341,7 +341,7 @@ static void test_hash_value(void) }; static const FLOAT float_hash_tests[] = { - 0.0, -1.0, 100.0, 1.0, 255.0, 1.234 + 0.0, -1.0, 100.0, 1.0, 255.0, 1.234, 2175.0, 6259.0 }; IDictionary *dict; @@ -848,6 +848,29 @@ static void test_Keys(void) VariantClear(&keys); + hr = IDictionary_RemoveAll(dict); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + + /* Integer key type. */ + V_VT(&key) = VT_I4; + V_I4(&key) = 0; + VariantInit(&item); + hr = IDictionary_Add(dict, &key, &item); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + + VariantInit(&keys); + hr = IDictionary_Keys(dict, &keys); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + ok(V_VT(&keys) == (VT_ARRAY|VT_VARIANT), "got %d\n", V_VT(&keys)); + + VariantInit(&key); + index = 0; + hr = SafeArrayGetElement(V_ARRAY(&keys), &index, &key); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + ok(V_VT(&key) == VT_I4, "got %d\n", V_VT(&key)); + + VariantClear(&keys); + IDictionary_Release(dict); } @@ -910,7 +933,7 @@ static void test_Item(void) static void test_Add(void) { - VARIANT key, item; + VARIANT item, key1, key2, hash1, hash2; IDictionary *dict; HRESULT hr; BSTR str; @@ -920,19 +943,52 @@ static void test_Add(void) ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); str = SysAllocString(L"testW"); - V_VT(&key) = VT_I2; - V_I2(&key) = 1; + V_VT(&key1) = VT_I2; + V_I2(&key1) = 1; V_VT(&item) = VT_BSTR|VT_BYREF; V_BSTRREF(&item) = &str; - hr = IDictionary_Add(dict, &key, &item); + hr = IDictionary_Add(dict, &key1, &item); ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); - hr = IDictionary_get_Item(dict, &key, &item); + hr = IDictionary_get_Item(dict, &key1, &item); ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); ok(V_VT(&item) == VT_BSTR, "got %d\n", V_VT(&item)); SysFreeString(str); + /* Items with matching key hashes, float keys. */ + V_VT(&key1) = VT_R4; + V_R4(&key1) = 2175.0; + + V_VT(&key2) = VT_R4; + V_R4(&key2) = 6259.0; + + VariantInit(&hash1); + hr = IDictionary_get_HashVal(dict, &key1, &hash1); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + VariantInit(&hash2); + hr = IDictionary_get_HashVal(dict, &key1, &hash2); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + ok(V_VT(&hash1) == VT_I4, "Unexpected type %d.\n", V_VT(&hash1)); + ok(V_VT(&hash2) == VT_I4, "Unexpected type %d.\n", V_VT(&hash2)); + ok(V_I4(&hash1) == V_I4(&hash2), "Unexpected hash %#lx.\n", V_I4(&hash1)); + + V_VT(&item) = VT_BSTR; + V_BSTR(&item) = SysAllocString(L"float1"); + + hr = IDictionary_Add(dict, &key1, &item); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + hr = IDictionary_Add(dict, &key2, &item); + ok(hr == S_OK, "Unexpected hr %#lx.\n", hr); + + V_VT(&key1) = VT_I4; + V_I4(&key1) = 2175; + + hr = IDictionary_Add(dict, &key1, &item); + ok(hr == CTL_E_KEY_ALREADY_EXISTS, "Unexpected hr %#lx.\n", hr); + + VariantClear(&item); + IDictionary_Release(dict); }