From 2e6ed06fdcfbb9b71dd550505e784f5d1041eff6 Mon Sep 17 00:00:00 2001 From: James Hawkins Date: Wed, 1 Aug 2007 14:25:45 -0700 Subject: [PATCH] msi: Reimplement joins to allow joining any number of tables, each of arbitrary size. --- dlls/msi/join.c | 308 +++++++++++++++++++++++--------------------- dlls/msi/query.h | 3 +- dlls/msi/sql.y | 36 +++++- dlls/msi/tests/db.c | 25 ++-- 4 files changed, 200 insertions(+), 172 deletions(-) diff --git a/dlls/msi/join.c b/dlls/msi/join.c index e0710789750..325f58e36f0 100644 --- a/dlls/msi/join.c +++ b/dlls/msi/join.c @@ -23,7 +23,6 @@ #include "windef.h" #include "winbase.h" #include "winerror.h" -#include "wine/debug.h" #include "msi.h" #include "msiquery.h" #include "objbase.h" @@ -31,157 +30,150 @@ #include "msipriv.h" #include "query.h" +#include "wine/debug.h" +#include "wine/unicode.h" + WINE_DEFAULT_DEBUG_CHANNEL(msidb); +typedef struct tagJOINTABLE +{ + struct list entry; + MSIVIEW *view; + UINT columns; + UINT rows; + UINT next_rows; +} JOINTABLE; + typedef struct tagMSIJOINVIEW { MSIVIEW view; MSIDATABASE *db; - MSIVIEW *left, *right; - UINT left_count, right_count; - UINT left_rows, right_rows; + struct list tables; + UINT columns; + UINT rows; } MSIJOINVIEW; static UINT JOIN_fetch_int( struct tagMSIVIEW *view, UINT row, UINT col, UINT *val ) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; - MSIVIEW *table; + JOINTABLE *table; + UINT cols = 0; + UINT prev_rows = 1; - TRACE("%p %d %d %p\n", jv, row, col, val ); + TRACE("%d, %d\n", row, col); - if( !jv->left || !jv->right ) + if (col == 0 || col > jv->columns) return ERROR_FUNCTION_FAILED; - if( (col==0) || (col>(jv->left_count + jv->right_count)) ) + if (row >= jv->rows) return ERROR_FUNCTION_FAILED; - if( row >= (jv->left_rows * jv->right_rows) ) - return ERROR_FUNCTION_FAILED; - - if( col <= jv->left_count ) + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) { - table = jv->left; - row = (row/jv->right_rows); - } - else - { - table = jv->right; - row = (row % jv->right_rows); - col -= jv->left_count; + if (col <= cols + table->columns) + { + row = (row % (jv->rows / table->next_rows)) / prev_rows; + col -= cols; + break; + } + + prev_rows = table->rows; + cols += table->columns; } - return table->ops->fetch_int( table, row, col, val ); + return table->view->ops->fetch_int( table->view, row, col, val ); } static UINT JOIN_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; - MSIVIEW *table; + JOINTABLE *table; + UINT cols = 0; + UINT prev_rows = 1; TRACE("%p %d %d %p\n", jv, row, col, stm ); - if( !jv->left || !jv->right ) + if (col == 0 || col > jv->columns) return ERROR_FUNCTION_FAILED; - if( (col==0) || (col>(jv->left_count + jv->right_count)) ) + if (row >= jv->rows) return ERROR_FUNCTION_FAILED; - if( row >= jv->left_rows * jv->right_rows ) - return ERROR_FUNCTION_FAILED; - - if( row <= jv->left_count ) + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) { - table = jv->left; - row = (row/jv->right_rows); - } - else - { - table = jv->right; - row = (row % jv->right_rows); - col -= jv->left_count; + if (col <= cols + table->columns) + { + row = (row % (jv->rows / table->next_rows)) / prev_rows; + col -= cols; + break; + } + + prev_rows = table->rows; + cols += table->columns; } - return table->ops->fetch_stream( table, row, col, stm ); + return table->view->ops->fetch_stream( table->view, row, col, stm ); } static UINT JOIN_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec ) { - MSIJOINVIEW *jv = (MSIJOINVIEW*)view; - MSIVIEW *table; - - TRACE("%p %d %p\n", jv, row, rec ); - - if( !jv->left || !jv->right ) - return ERROR_FUNCTION_FAILED; - - if( row >= jv->left_rows * jv->right_rows ) - return ERROR_FUNCTION_FAILED; - - if( row <= jv->left_count ) - { - table = jv->left; - row = (row/jv->right_rows); - } - else - { - table = jv->right; - row = (row % jv->right_rows); - } - - return table->ops->get_row(table, row, rec); + FIXME("(%p, %d, %p): stub!\n", view, row, rec); + return ERROR_FUNCTION_FAILED; } static UINT JOIN_execute( struct tagMSIVIEW *view, MSIRECORD *record ) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; - UINT r, *ldata = NULL, *rdata = NULL; + JOINTABLE *table; + UINT r, rows; TRACE("%p %p\n", jv, record); - if( !jv->left || !jv->right ) - return ERROR_FUNCTION_FAILED; - - r = jv->left->ops->execute( jv->left, NULL ); - if (r != ERROR_SUCCESS) - return r; - - r = jv->right->ops->execute( jv->right, NULL ); - if (r != ERROR_SUCCESS) - return r; - - /* get the number of rows in each table */ - r = jv->left->ops->get_dimensions( jv->left, &jv->left_rows, NULL ); - if( r != ERROR_SUCCESS ) + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) { - ERR("can't get left table dimensions\n"); - goto end; + table->view->ops->execute(table->view, NULL); + + r = table->view->ops->get_dimensions(table->view, &table->rows, NULL); + if (r != ERROR_SUCCESS) + { + ERR("failed to get table dimensions\n"); + return r; + } + + /* each table must have at least one row */ + if (table->rows == 0) + { + jv->rows = 0; + return ERROR_SUCCESS; + } + + if (jv->rows == 0) + jv->rows = table->rows; + else + jv->rows *= table->rows; } - r = jv->right->ops->get_dimensions( jv->right, &jv->right_rows, NULL ); - if( r != ERROR_SUCCESS ) + rows = jv->rows; + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) { - ERR("can't get right table dimensions\n"); - goto end; + rows /= table->rows; + table->next_rows = rows; } -end: - msi_free( ldata ); - msi_free( rdata ); - - return r; + return ERROR_SUCCESS; } static UINT JOIN_close( struct tagMSIVIEW *view ) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; + JOINTABLE *table; TRACE("%p\n", jv ); - if( !jv->left || !jv->right ) - return ERROR_FUNCTION_FAILED; - - jv->left->ops->close( jv->left ); - jv->right->ops->close( jv->right ); + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) + { + table->view->ops->close(table->view); + } return ERROR_SUCCESS; } @@ -192,16 +184,11 @@ static UINT JOIN_get_dimensions( struct tagMSIVIEW *view, UINT *rows, UINT *cols TRACE("%p %p %p\n", jv, rows, cols ); - if( cols ) - *cols = jv->left_count + jv->right_count; + if (cols) + *cols = jv->columns; - if( rows ) - { - if( !jv->left || !jv->right ) - return ERROR_FUNCTION_FAILED; - - *rows = jv->left_rows * jv->right_rows; - } + if (rows) + *rows = jv->rows; return ERROR_SUCCESS; } @@ -210,48 +197,46 @@ static UINT JOIN_get_column_info( struct tagMSIVIEW *view, UINT n, LPWSTR *name, UINT *type ) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; + JOINTABLE *table; + UINT cols = 0; TRACE("%p %d %p %p\n", jv, n, name, type ); - if( !jv->left || !jv->right ) + if (n == 0 || n > jv->columns) return ERROR_FUNCTION_FAILED; - if( (n==0) || (n>(jv->left_count + jv->right_count)) ) - return ERROR_FUNCTION_FAILED; + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) + { + if (n <= cols + table->columns) + return table->view->ops->get_column_info(table->view, n - cols, name, type); - if( n <= jv->left_count ) - return jv->left->ops->get_column_info( jv->left, n, name, type ); + cols += table->columns; + } - n = n - jv->left_count; - - return jv->right->ops->get_column_info( jv->right, n, name, type ); + return ERROR_FUNCTION_FAILED; } static UINT JOIN_modify( struct tagMSIVIEW *view, MSIMODIFY eModifyMode, MSIRECORD *rec, UINT row ) { - MSIJOINVIEW *jv = (MSIJOINVIEW*)view; - - TRACE("%p %d %p\n", jv, eModifyMode, rec ); - + TRACE("%p %d %p\n", view, eModifyMode, rec); return ERROR_FUNCTION_FAILED; } static UINT JOIN_delete( struct tagMSIVIEW *view ) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; + JOINTABLE *table; TRACE("%p\n", jv ); - if( jv->left ) - jv->left->ops->delete( jv->left ); - jv->left = NULL; + LIST_FOR_EACH_ENTRY(table, &jv->tables, JOINTABLE, entry) + { + table->view->ops->delete(table->view); + table->view = NULL; + } - if( jv->right ) - jv->right->ops->delete( jv->right ); - jv->right = NULL; - - msi_free( jv ); + msi_free(jv); return ERROR_SUCCESS; } @@ -260,10 +245,27 @@ static UINT JOIN_find_matching_rows( struct tagMSIVIEW *view, UINT col, UINT val, UINT *row, MSIITERHANDLE *handle ) { MSIJOINVIEW *jv = (MSIJOINVIEW*)view; + UINT i, row_value; - FIXME("%p, %d, %u, %p\n", jv, col, val, *handle); + TRACE("%p, %d, %u, %p\n", view, col, val, *handle); - return ERROR_FUNCTION_FAILED; + if (col == 0 || col > jv->columns) + return ERROR_INVALID_PARAMETER; + + for (i = (UINT)*handle; i < jv->rows; i++) + { + if (view->ops->fetch_int( view, i, col, &row_value ) != ERROR_SUCCESS) + continue; + + if (row_value == val) + { + *row = i; + (*(UINT *)handle) = i + 1; + return ERROR_SUCCESS; + } + } + + return ERROR_NO_MORE_ITEMS; } static const MSIVIEWOPS join_ops = @@ -287,13 +289,14 @@ static const MSIVIEWOPS join_ops = NULL, }; -UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, - LPCWSTR left, LPCWSTR right ) +UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, LPWSTR tables ) { MSIJOINVIEW *jv = NULL; UINT r = ERROR_SUCCESS; + JOINTABLE *table; + LPWSTR ptr; - TRACE("%p (%s,%s)\n", jv, debugstr_w(left), debugstr_w(right) ); + TRACE("%p (%s)\n", jv, debugstr_w(tables) ); jv = msi_alloc_zero( sizeof *jv ); if( !jv ) @@ -302,35 +305,42 @@ UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, /* fill the structure */ jv->view.ops = &join_ops; jv->db = db; + jv->columns = 0; + jv->rows = 0; - /* create the tables to join */ - r = TABLE_CreateView( db, left, &jv->left ); - if( r != ERROR_SUCCESS ) - { - ERR("can't create left table\n"); - goto end; - } + list_init(&jv->tables); - r = TABLE_CreateView( db, right, &jv->right ); - if( r != ERROR_SUCCESS ) + while (*tables) { - ERR("can't create right table\n"); - goto end; - } + if ((ptr = strchrW(tables, ' '))) + *ptr = '\0'; - /* get the number of columns in each table */ - r = jv->left->ops->get_dimensions( jv->left, NULL, &jv->left_count ); - if( r != ERROR_SUCCESS ) - { - ERR("can't get left table dimensions\n"); - goto end; - } + table = msi_alloc(sizeof(JOINTABLE)); + if (!table) + return ERROR_OUTOFMEMORY; - r = jv->right->ops->get_dimensions( jv->right, NULL, &jv->right_count ); - if( r != ERROR_SUCCESS ) - { - ERR("can't get right table dimensions\n"); - goto end; + r = TABLE_CreateView( db, tables, &table->view ); + if( r != ERROR_SUCCESS ) + { + ERR("can't create table\n"); + goto end; + } + + r = table->view->ops->get_dimensions( table->view, NULL, &table->columns ); + if( r != ERROR_SUCCESS ) + { + ERR("can't get table dimensions\n"); + goto end; + } + + jv->columns += table->columns; + + list_add_head( &jv->tables, &table->entry ); + + if (!ptr) + break; + + tables = ptr + 1; } *view = &jv->view; diff --git a/dlls/msi/query.h b/dlls/msi/query.h index 268989d2c34..4cf47c7a657 100644 --- a/dlls/msi/query.h +++ b/dlls/msi/query.h @@ -119,8 +119,7 @@ UINT UPDATE_CreateView( MSIDATABASE *db, MSIVIEW **view, LPCWSTR table, UINT DELETE_CreateView( MSIDATABASE *db, MSIVIEW **view, MSIVIEW *table ); -UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, - LPCWSTR left, LPCWSTR right ); +UINT JOIN_CreateView( MSIDATABASE *db, MSIVIEW **view, LPWSTR tables ); UINT ALTER_CreateView( MSIDATABASE *db, MSIVIEW **view, LPCWSTR name, column_info *colinfo, int hold ); diff --git a/dlls/msi/sql.y b/dlls/msi/sql.y index f5e1cf936eb..59a67b50f4d 100644 --- a/dlls/msi/sql.y +++ b/dlls/msi/sql.y @@ -53,6 +53,7 @@ static LPWSTR SQL_getstring( void *info, const struct sql_str *str ); static INT SQL_getint( void *info ); static int sql_lex( void *SQL_lval, SQL_input *info ); +static LPWSTR parser_add_table( LPWSTR list, LPWSTR table ); static void *parser_alloc( void *info, unsigned int sz ); static column_info *parser_alloc_column( void *info, LPCWSTR table, LPCWSTR column ); @@ -101,7 +102,7 @@ static struct expr * EXPR_wildcard( void *info ); %nonassoc END_OF_FILE ILLEGAL SPACE UNCLOSED_STRING COMMENT FUNCTION COLUMN AGG_FUNCTION. -%type table id +%type table tablelist id %type selcollist column column_and_type column_def table_def %type column_assignment update_assign_list constlist %type query from fromtable selectfrom unorderedsel @@ -466,18 +467,32 @@ fromtable: if( r != ERROR_SUCCESS || !$$ ) YYABORT; } - | TK_FROM table TK_COMMA table + | TK_FROM tablelist { SQL_input* sql = (SQL_input*) info; UINT r; - /* only support inner joins on two tables */ - r = JOIN_CreateView( sql->db, &$$, $2, $4 ); + r = JOIN_CreateView( sql->db, &$$, $2 ); + msi_free( $2 ); if( r != ERROR_SUCCESS ) YYABORT; } ; +tablelist: + table + { + $$ = strdupW($1); + } + | + table TK_COMMA tablelist + { + $$ = parser_add_table($3, $1); + if (!$$) + YYABORT; + } + ; + expr: TK_LP expr TK_RP { @@ -663,6 +678,19 @@ number: %% +static LPWSTR parser_add_table(LPWSTR list, LPWSTR table) +{ + DWORD size = lstrlenW(list) + lstrlenW(table) + 2; + static const WCHAR space[] = {' ',0}; + + list = msi_realloc(list, size * sizeof(WCHAR)); + if (!list) return NULL; + + lstrcatW(list, space); + lstrcatW(list, table); + return list; +} + static void *parser_alloc( void *info, unsigned int sz ) { SQL_input* sql = (SQL_input*) info; diff --git a/dlls/msi/tests/db.c b/dlls/msi/tests/db.c index a857dbf9251..293cf0cf9f7 100644 --- a/dlls/msi/tests/db.c +++ b/dlls/msi/tests/db.c @@ -2890,10 +2890,7 @@ static void test_join(void) ok( r == ERROR_SUCCESS, "failed to open view: %d\n", r ); r = MsiViewExecute(hview, 0); - todo_wine - { - ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r ); - } + ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r ); i = 0; data_correct = TRUE; @@ -2919,10 +2916,7 @@ static void test_join(void) } ok( data_correct, "data returned in the wrong order\n"); - todo_wine - { - ok( i == 6, "Expected 6 rows, got %d\n", i ); - } + ok( i == 6, "Expected 6 rows, got %d\n", i ); ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r ); MsiViewClose(hview); @@ -3000,7 +2994,7 @@ static void test_join(void) MsiCloseHandle(hrec); } - todo_wine ok( data_correct, "data returned in the wrong order\n"); + ok( data_correct, "data returned in the wrong order\n"); ok( i == 6, "Expected 6 rows, got %d\n", i ); ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r ); @@ -3048,7 +3042,7 @@ static void test_join(void) i++; MsiCloseHandle(hrec); } - todo_wine ok( data_correct, "data returned in the wrong order\n"); + ok( data_correct, "data returned in the wrong order\n"); ok( i == 6, "Expected 6 rows, got %d\n", i ); ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r ); @@ -3058,10 +3052,10 @@ static void test_join(void) query = "SELECT * FROM `One`, `Two`, `Three` "; r = MsiDatabaseOpenView(hdb, query, &hview); - todo_wine ok( r == ERROR_SUCCESS, "failed to open view: %d\n", r ); + ok( r == ERROR_SUCCESS, "failed to open view: %d\n", r ); r = MsiViewExecute(hview, 0); - todo_wine ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r ); + ok( r == ERROR_SUCCESS, "failed to execute view: %d\n", r ); i = 0; data_correct = TRUE; @@ -3099,11 +3093,8 @@ static void test_join(void) } ok( data_correct, "data returned in the wrong order\n"); - todo_wine - { - ok( i == 6, "Expected 6 rows, got %d\n", i ); - ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r ); - } + ok( i == 6, "Expected 6 rows, got %d\n", i ); + ok( r == ERROR_NO_MORE_ITEMS, "expected no more items: %d\n", r ); MsiViewClose(hview); MsiCloseHandle(hview);