Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored IEnumerable<> calls to use 'yield return' for streaming #27

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/LinqToExcel/Extensions/CommonExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ public static object Cast(this object @object, Type castType)

public static IEnumerable<TResult> Cast<TResult>(this IEnumerable<object> list, Func<object, TResult> caster)
{
var results = new List<TResult>();
foreach (var item in list)
results.Add(caster(item));
return results;
yield return caster(item);
}

public static IEnumerable<TResult> Cast<TResult>(this IEnumerable<object> list)
Expand Down
86 changes: 46 additions & 40 deletions src/LinqToExcel/Query/ExcelQueryExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,34 +164,34 @@ protected IEnumerable<object> GetDataResults(SqlParts sql, QueryModel queryModel
{
IEnumerable<object> results;
OleDbDataReader data = null;
using (var conn = new OleDbConnection(_connectionString))
using (var command = conn.CreateCommand())
var conn = new OleDbConnection(_connectionString);
var command = conn.CreateCommand();

conn.Open();
command.CommandText = sql.ToString();
command.Parameters.AddRange(sql.Parameters.ToArray());
try { data = command.ExecuteReader(); }
catch (OleDbException e)
{
conn.Open();
command.CommandText = sql.ToString();
command.Parameters.AddRange(sql.Parameters.ToArray());
try { data = command.ExecuteReader(); }
catch (OleDbException e)
{
if (e.Message.Contains(_args.WorksheetName))
throw new DataException(
string.Format("'{0}' is not a valid worksheet name. Valid worksheet names are: '{1}'",
_args.WorksheetName, string.Join("', '", ExcelUtilities.GetWorksheetNames(_args.FileName).ToArray())));
if (!CheckIfInvalidColumnNameUsed(sql))
throw e;
}

var columns = ExcelUtilities.GetColumnNames(data);
LogColumnMappingWarnings(columns);
if (columns.Count() == 1 && columns.First() == "Expr1000")
results = GetScalarResults(data);
else if (queryModel.MainFromClause.ItemType == typeof(Row))
results = GetRowResults(data, columns);
else if (queryModel.MainFromClause.ItemType == typeof(RowNoHeader))
results = GetRowNoHeaderResults(data);
else
results = GetTypeResults(data, columns, queryModel);
if (e.Message.Contains(_args.WorksheetName))
throw new DataException(
string.Format("'{0}' is not a valid worksheet name. Valid worksheet names are: '{1}'",
_args.WorksheetName, string.Join("', '", ExcelUtilities.GetWorksheetNames(_args.FileName).ToArray())));
if (!CheckIfInvalidColumnNameUsed(sql))
throw e;
}

var columns = ExcelUtilities.GetColumnNames(data);
LogColumnMappingWarnings(columns);
if (columns.Count() == 1 && columns.First() == "Expr1000")
results = GetScalarResults(data, conn, command);
else if (queryModel.MainFromClause.ItemType == typeof(Row))
results = GetRowResults(data, columns, conn, command);
else if (queryModel.MainFromClause.ItemType == typeof(RowNoHeader))
results = GetRowNoHeaderResults(data, conn, command);
else
results = GetTypeResults(data, columns, queryModel, conn, command);

return results;
}

Expand Down Expand Up @@ -229,9 +229,8 @@ private bool CheckIfInvalidColumnNameUsed(SqlParts sql)
return false;
}

private IEnumerable<object> GetRowResults(IDataReader data, IEnumerable<string> columns)
private IEnumerable<object> GetRowResults(IDataReader data, IEnumerable<string> columns, OleDbConnection conn, OleDbCommand command)
{
var results = new List<object>();
var columnIndexMapping = new Dictionary<string, int>();
for (var i = 0; i < columns.Count(); i++)
columnIndexMapping[columns.ElementAt(i)] = i;
Expand All @@ -241,27 +240,29 @@ private IEnumerable<object> GetRowResults(IDataReader data, IEnumerable<string>
IList<Cell> cells = new List<Cell>();
for (var i = 0; i < columns.Count(); i++)
cells.Add(new Cell(data[i]));
results.CallMethod("Add", new Row(cells, columnIndexMapping));
yield return new Row(cells, columnIndexMapping);
}
return results.AsEnumerable();

conn.Dispose();
command.Dispose();
}

private IEnumerable<object> GetRowNoHeaderResults(OleDbDataReader data)
private IEnumerable<object> GetRowNoHeaderResults(OleDbDataReader data, OleDbConnection conn, OleDbCommand command)
{
var results = new List<object>();
while (data.Read())
{
IList<Cell> cells = new List<Cell>();
for (var i = 0; i < data.FieldCount; i++)
cells.Add(new Cell(data[i]));
results.CallMethod("Add", new RowNoHeader(cells));
yield return new RowNoHeader(cells);
}
return results.AsEnumerable();

conn.Dispose();
command.Dispose();
}

private IEnumerable<object> GetTypeResults(IDataReader data, IEnumerable<string> columns, QueryModel queryModel)
private IEnumerable<object> GetTypeResults(IDataReader data, IEnumerable<string> columns, QueryModel queryModel, OleDbConnection conn, OleDbCommand command)
{
var results = new List<object>();
var fromType = queryModel.MainFromClause.ItemType;
var props = fromType.GetProperties();
if (_args.StrictMapping.Value != StrictMappingType.None)
Expand All @@ -278,9 +279,11 @@ private IEnumerable<object> GetTypeResults(IDataReader data, IEnumerable<string>
if (columns.Contains(columnName))
result.SetProperty(prop.Name, GetColumnValue(data, columnName, prop.Name).Cast(prop.PropertyType));
}
results.Add(result);
yield return result;
}
return results.AsEnumerable();

conn.Dispose();
command.Dispose();
}

private void ConfirmStrictMapping(IEnumerable<string> columns, PropertyInfo[] properties, StrictMappingType strictMappingType)
Expand Down Expand Up @@ -323,10 +326,13 @@ private object GetColumnValue(IDataRecord data, string columnName, string proper
data[columnName];
}

private IEnumerable<object> GetScalarResults(IDataReader data)
private IEnumerable<object> GetScalarResults(IDataReader data, OleDbConnection conn, OleDbCommand command)
{
data.Read();
return new List<object> { data[0] };
yield return data[0];

conn.Dispose();
command.Dispose();
}

private void LogSqlStatement(SqlParts sqlParts)
Expand Down