From ef445cd45239b0fa2189947f973419df1f168ad8 Mon Sep 17 00:00:00 2001 From: Andrew Mattie Date: Fri, 25 Oct 2013 09:45:11 -0700 Subject: [PATCH] Refactored IEnumerable<> calls to use 'yield return' for streaming for large files --- .../Extensions/CommonExtensions.cs | 4 +- src/LinqToExcel/Query/ExcelQueryExecutor.cs | 86 ++++++++++--------- 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/src/LinqToExcel/Extensions/CommonExtensions.cs b/src/LinqToExcel/Extensions/CommonExtensions.cs index f9484ba..f280432 100644 --- a/src/LinqToExcel/Extensions/CommonExtensions.cs +++ b/src/LinqToExcel/Extensions/CommonExtensions.cs @@ -52,10 +52,8 @@ public static object Cast(this object @object, Type castType) public static IEnumerable Cast(this IEnumerable list, Func caster) { - var results = new List(); foreach (var item in list) - results.Add(caster(item)); - return results; + yield return caster(item); } public static IEnumerable Cast(this IEnumerable list) diff --git a/src/LinqToExcel/Query/ExcelQueryExecutor.cs b/src/LinqToExcel/Query/ExcelQueryExecutor.cs index 0fc53bd..b7ea05a 100644 --- a/src/LinqToExcel/Query/ExcelQueryExecutor.cs +++ b/src/LinqToExcel/Query/ExcelQueryExecutor.cs @@ -164,34 +164,34 @@ protected IEnumerable GetDataResults(SqlParts sql, QueryModel queryModel { IEnumerable results; OleDbDataReader data = null; - using (var conn = new OleDbConnection(_connectionString)) - using (var command = conn.CreateCommand()) + var conn = new OleDbConnection(_connectionString); + var command = conn.CreateCommand(); + + conn.Open(); + command.CommandText = sql.ToString(); + command.Parameters.AddRange(sql.Parameters.ToArray()); + try { data = command.ExecuteReader(); } + catch (OleDbException e) { - conn.Open(); - command.CommandText = sql.ToString(); - command.Parameters.AddRange(sql.Parameters.ToArray()); - try { data = command.ExecuteReader(); } - catch (OleDbException e) - { - if (e.Message.Contains(_args.WorksheetName)) - throw new DataException( - string.Format("'{0}' is not a valid worksheet name. Valid worksheet names are: '{1}'", - _args.WorksheetName, string.Join("', '", ExcelUtilities.GetWorksheetNames(_args.FileName).ToArray()))); - if (!CheckIfInvalidColumnNameUsed(sql)) - throw e; - } - - var columns = ExcelUtilities.GetColumnNames(data); - LogColumnMappingWarnings(columns); - if (columns.Count() == 1 && columns.First() == "Expr1000") - results = GetScalarResults(data); - else if (queryModel.MainFromClause.ItemType == typeof(Row)) - results = GetRowResults(data, columns); - else if (queryModel.MainFromClause.ItemType == typeof(RowNoHeader)) - results = GetRowNoHeaderResults(data); - else - results = GetTypeResults(data, columns, queryModel); + if (e.Message.Contains(_args.WorksheetName)) + throw new DataException( + string.Format("'{0}' is not a valid worksheet name. Valid worksheet names are: '{1}'", + _args.WorksheetName, string.Join("', '", ExcelUtilities.GetWorksheetNames(_args.FileName).ToArray()))); + if (!CheckIfInvalidColumnNameUsed(sql)) + throw e; } + + var columns = ExcelUtilities.GetColumnNames(data); + LogColumnMappingWarnings(columns); + if (columns.Count() == 1 && columns.First() == "Expr1000") + results = GetScalarResults(data, conn, command); + else if (queryModel.MainFromClause.ItemType == typeof(Row)) + results = GetRowResults(data, columns, conn, command); + else if (queryModel.MainFromClause.ItemType == typeof(RowNoHeader)) + results = GetRowNoHeaderResults(data, conn, command); + else + results = GetTypeResults(data, columns, queryModel, conn, command); + return results; } @@ -229,9 +229,8 @@ private bool CheckIfInvalidColumnNameUsed(SqlParts sql) return false; } - private IEnumerable GetRowResults(IDataReader data, IEnumerable columns) + private IEnumerable GetRowResults(IDataReader data, IEnumerable columns, OleDbConnection conn, OleDbCommand command) { - var results = new List(); var columnIndexMapping = new Dictionary(); for (var i = 0; i < columns.Count(); i++) columnIndexMapping[columns.ElementAt(i)] = i; @@ -241,27 +240,29 @@ private IEnumerable GetRowResults(IDataReader data, IEnumerable IList cells = new List(); for (var i = 0; i < columns.Count(); i++) cells.Add(new Cell(data[i])); - results.CallMethod("Add", new Row(cells, columnIndexMapping)); + yield return new Row(cells, columnIndexMapping); } - return results.AsEnumerable(); + + conn.Dispose(); + command.Dispose(); } - private IEnumerable GetRowNoHeaderResults(OleDbDataReader data) + private IEnumerable GetRowNoHeaderResults(OleDbDataReader data, OleDbConnection conn, OleDbCommand command) { - var results = new List(); while (data.Read()) { IList cells = new List(); for (var i = 0; i < data.FieldCount; i++) cells.Add(new Cell(data[i])); - results.CallMethod("Add", new RowNoHeader(cells)); + yield return new RowNoHeader(cells); } - return results.AsEnumerable(); + + conn.Dispose(); + command.Dispose(); } - private IEnumerable GetTypeResults(IDataReader data, IEnumerable columns, QueryModel queryModel) + private IEnumerable GetTypeResults(IDataReader data, IEnumerable columns, QueryModel queryModel, OleDbConnection conn, OleDbCommand command) { - var results = new List(); var fromType = queryModel.MainFromClause.ItemType; var props = fromType.GetProperties(); if (_args.StrictMapping.Value != StrictMappingType.None) @@ -278,9 +279,11 @@ private IEnumerable GetTypeResults(IDataReader data, IEnumerable if (columns.Contains(columnName)) result.SetProperty(prop.Name, GetColumnValue(data, columnName, prop.Name).Cast(prop.PropertyType)); } - results.Add(result); + yield return result; } - return results.AsEnumerable(); + + conn.Dispose(); + command.Dispose(); } private void ConfirmStrictMapping(IEnumerable columns, PropertyInfo[] properties, StrictMappingType strictMappingType) @@ -323,10 +326,13 @@ private object GetColumnValue(IDataRecord data, string columnName, string proper data[columnName]; } - private IEnumerable GetScalarResults(IDataReader data) + private IEnumerable GetScalarResults(IDataReader data, OleDbConnection conn, OleDbCommand command) { data.Read(); - return new List { data[0] }; + yield return data[0]; + + conn.Dispose(); + command.Dispose(); } private void LogSqlStatement(SqlParts sqlParts)