From ace54f7ed21fd0b67b87c93f9ff97b478a03d9ee Mon Sep 17 00:00:00 2001 From: Martin Willey Date: Thu, 10 Mar 2016 19:57:07 +0100 Subject: [PATCH] Hook up ProviderRepository to FactoryTools --- .../Utilities/DbProviderFactoryRepository.cs | 65 +++++++------ .../Utilities/FactoryTools.cs | 97 ++++++++++--------- .../DbProviderFactoryRepositoryTest.cs | 64 +++++++++++- 3 files changed, 142 insertions(+), 84 deletions(-) diff --git a/DatabaseSchemaReader/Utilities/DbProviderFactoryRepository.cs b/DatabaseSchemaReader/Utilities/DbProviderFactoryRepository.cs index cab5ea46..b857340b 100644 --- a/DatabaseSchemaReader/Utilities/DbProviderFactoryRepository.cs +++ b/DatabaseSchemaReader/Utilities/DbProviderFactoryRepository.cs @@ -13,17 +13,17 @@ namespace DatabaseSchemaReader.Utilities ///declared at app.config or machine.config. Basically extracted from ///http://sandrinodimattia.net/dbproviderfactoryrepository-managing-dbproviderfactories-in-code/ /// - public static class DbProviderFactoryRepository + public class DbProviderFactoryRepository { /// ///The table containing all the data. /// - private static DataTable _dbProviderFactoryTable; + private DataTable _dbProviderFactoryTable; /// ///Initialize the repository. /// - static DbProviderFactoryRepository() + public DbProviderFactoryRepository() { LoadDbProviderFactories(); } @@ -32,7 +32,7 @@ static DbProviderFactoryRepository() ///Gets all providers. /// /// - public static IEnumerable GetAllDescriptions() + public IEnumerable GetAllDescriptions() { return _dbProviderFactoryTable.Rows.Cast().Select(o => new DbProviderFactoryDescription(o)); } @@ -42,7 +42,7 @@ public static IEnumerable GetAllDescriptions() /// /// /// - public static DbProviderFactoryDescription GetDescriptionByInvariant(string invariant) + public DbProviderFactoryDescription GetDescriptionByInvariant(string invariant) { var row = _dbProviderFactoryTable.Rows.Cast() @@ -55,7 +55,7 @@ public static DbProviderFactoryDescription GetDescriptionByInvariant(string inva /// /// The description. /// - public static DbProviderFactory GetFactory(DbProviderFactoryDescription description) + public DbProviderFactory GetFactory(DbProviderFactoryDescription description) { var providerType = //AssemblyHelper.LoadTypeFrom(description.AssemblyQualifiedName); Type.GetType(description.AssemblyQualifiedName); @@ -81,7 +81,7 @@ public static DbProviderFactory GetFactory(DbProviderFactoryDescription descript /// /// The invariant. /// - public static DbProviderFactory GetFactory(string invariant) + public DbProviderFactory GetFactory(string invariant) { if (string.IsNullOrEmpty(invariant)) { @@ -98,7 +98,7 @@ public static DbProviderFactory GetFactory(string invariant) /// The path. /// /// $Path does not {path} exist. - public static void LoadExternalDbProviderAssemblies(string path) + public void LoadExternalDbProviderAssemblies(string path) { LoadExternalDbProviderAssemblies(path, true); } @@ -110,7 +110,7 @@ public static void LoadExternalDbProviderAssemblies(string path) /// if set to true [include subfolders]. /// /// $Path does not {path} exist. - public static void LoadExternalDbProviderAssemblies(string path, bool includeSubfolders) + public void LoadExternalDbProviderAssemblies(string path, bool includeSubfolders) { if (string.IsNullOrEmpty(path)) { @@ -180,30 +180,11 @@ public static void LoadExternalDbProviderAssemblies(string path, bool includeSub } } - /// - /// Gets the loadable types. - /// - /// The assembly. - /// - /// assembly - public static IEnumerable GetLoadableTypes(this Assembly assembly) - { - if (assembly == null) throw new ArgumentNullException("assembly"); - try - { - return assembly.GetTypes(); - } - catch (ReflectionTypeLoadException e) - { - return e.Types.Where(t => t != null); - } - } - /// ///Adds the specified provider. /// /// The provider. - public static void Add(DbProviderFactoryDescription provider) + public void Add(DbProviderFactoryDescription provider) { Delete(provider); _dbProviderFactoryTable.Rows.Add(provider.Name, provider.Description, provider.InvariantName, provider.AssemblyQualifiedName); @@ -213,7 +194,7 @@ public static void Add(DbProviderFactoryDescription provider) ///Deletes the specified provider if present. /// /// The provider. - private static void Delete(DbProviderFactoryDescription provider) + private void Delete(DbProviderFactoryDescription provider) { var row = _dbProviderFactoryTable.Rows.Cast() @@ -227,9 +208,31 @@ private static void Delete(DbProviderFactoryDescription provider) /// ///Opens the table. /// - private static void LoadDbProviderFactories() + private void LoadDbProviderFactories() { _dbProviderFactoryTable = DbProviderFactories.GetFactoryClasses(); } } + + internal static class AssemblyExtensions + { + /// + /// Gets the loadable types. + /// + /// The assembly. + /// + /// assembly + public static IEnumerable GetLoadableTypes(this Assembly assembly) + { + if (assembly == null) throw new ArgumentNullException("assembly"); + try + { + return assembly.GetTypes(); + } + catch (ReflectionTypeLoadException e) + { + return e.Types.Where(t => t != null); + } + } + } } \ No newline at end of file diff --git a/DatabaseSchemaReader/Utilities/FactoryTools.cs b/DatabaseSchemaReader/Utilities/FactoryTools.cs index 358a86e7..7b3b9fb8 100644 --- a/DatabaseSchemaReader/Utilities/FactoryTools.cs +++ b/DatabaseSchemaReader/Utilities/FactoryTools.cs @@ -1,48 +1,49 @@ -using System; -using System.Data; -using System.Data.Common; - -namespace DatabaseSchemaReader.Utilities -{ - /// - /// Tools to help with DbProviderFactory - /// - public static class FactoryTools - { - private static DbProviderFactory _manualProviderFactory; - - /// - /// Finds the factory. - /// - /// Name of the provider. - /// - public static DbProviderFactory GetFactory(string providerName) - { - //a simple static manual override. - if (_manualProviderFactory != null) return _manualProviderFactory; - return DbProviderFactories.GetFactory(providerName); - } - - - /// - /// Adds an existing factory. Call this before creating the DatabaseReader or SchemaReader. Use with care! - /// - /// The factory. - /// schemaReader - public static void AddFactory(DbProviderFactory factory) - { - if (factory == null) throw new ArgumentNullException("factory"); - _manualProviderFactory = factory; - } - - - /// - /// List of all the valid Providers. Use the ProviderInvariantName to fill ProviderName property - /// - /// - public static DataTable Providers() - { - return DbProviderFactories.GetFactoryClasses(); - } - } -} +using System; +using System.Data; +using System.Data.Common; + +namespace DatabaseSchemaReader.Utilities +{ + /// + /// Tools to help with DbProviderFactory + /// + public static class FactoryTools + { + /// + /// Finds the factory. You can override with (simple) or + /// + /// Name of the provider. + /// + public static DbProviderFactory GetFactory(string providerName) + { + //a simple static manual override. + if (SingleProviderFactory != null) return SingleProviderFactory; + if (ProviderRepository != null) return ProviderRepository.GetFactory(providerName); + return DbProviderFactories.GetFactory(providerName); + } + + + /// + /// Adds an existing factory. Call this before creating the DatabaseReader or SchemaReader. Use with care! + /// + public static DbProviderFactory SingleProviderFactory { get; set; } + + /// + /// Gets or sets a provider repository. + /// + /// + /// The provider repository. + /// + public static DbProviderFactoryRepository ProviderRepository { get; set; } + + + /// + /// List of all the valid Providers. Use the ProviderInvariantName to fill ProviderName property + /// + /// + public static DataTable Providers() + { + return DbProviderFactories.GetFactoryClasses(); + } + } +} diff --git a/DatabaseSchemaReaderTest/Utilities/DbProviderFactoryRepositoryTest.cs b/DatabaseSchemaReaderTest/Utilities/DbProviderFactoryRepositoryTest.cs index 6b6818b7..5d023f73 100644 --- a/DatabaseSchemaReaderTest/Utilities/DbProviderFactoryRepositoryTest.cs +++ b/DatabaseSchemaReaderTest/Utilities/DbProviderFactoryRepositoryTest.cs @@ -1,4 +1,7 @@ using System; +using System.Data.Common; +using System.Data.SqlClient; +using DatabaseSchemaReader; using DatabaseSchemaReader.Utilities; #if !NUNIT using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -6,6 +9,7 @@ using NUnit.Framework; using TestClass = NUnit.Framework.TestFixtureAttribute; using TestMethod = NUnit.Framework.TestAttribute; +using TestCleanup = NUnit.Framework.TearDownAttribute; #endif namespace DatabaseSchemaReaderTest.Utilities @@ -26,22 +30,72 @@ public void TestRepository() }; // Initialize the repository. - DbProviderFactoryRepository.Add(manualDescription); + var repo = new DbProviderFactoryRepository(); + repo.Add(manualDescription); - var descs = DbProviderFactoryRepository.GetAllDescriptions(); + var descs = repo.GetAllDescriptions(); foreach (var description in descs) { //get the description individually - var desc = DbProviderFactoryRepository.GetDescriptionByInvariant(description.InvariantName); + var desc = repo.GetDescriptionByInvariant(description.InvariantName); Assert.AreEqual(description.AssemblyQualifiedName, desc.AssemblyQualifiedName); //get a factory - var factory = DbProviderFactoryRepository.GetFactory(desc); + var factory = repo.GetFactory(desc); //may be null if not accessible } //look in the current directory - DbProviderFactoryRepository.LoadExternalDbProviderAssemblies(Environment.CurrentDirectory); + repo.LoadExternalDbProviderAssemblies(Environment.CurrentDirectory); + } + + [TestMethod] + public void FactoryToolsTest() + { + const string providername = "System.Data.SqlClient"; + + //this is normally used + var provider = FactoryTools.GetFactory(providername); + Assert.AreEqual("System.Data.SqlClient.SqlClientFactory", provider.GetType().FullName, "No override, returns SqlClient"); + + //override with a repository + FactoryTools.ProviderRepository = new DbProviderFactoryRepository(); + var manualDescription = new DbProviderFactoryDescription + { + Description = ".NET Framework Data Provider for SuperDuperDatabase", + InvariantName = "SuperDuperDatabase", + Name = "SuperDuperDatabase Data Provider", + AssemblyQualifiedName = typeof(SuperDuperProviderFactory).AssemblyQualifiedName, + }; + FactoryTools.ProviderRepository.Add(manualDescription); + + provider = FactoryTools.GetFactory(providername); + Assert.AreEqual("System.Data.SqlClient.SqlClientFactory", provider.GetType().FullName, "Overridden, but returns underlying SqlClient"); + provider = FactoryTools.GetFactory("SuperDuperDatabase"); + Assert.AreEqual(typeof(SuperDuperProviderFactory), provider.GetType(), "Overridden, returns manually added provider"); + + //override with a single provider + FactoryTools.SingleProviderFactory = SqlClientFactory.Instance; + provider = FactoryTools.GetFactory("Xxxx"); + Assert.AreEqual("System.Data.SqlClient.SqlClientFactory", provider.GetType().FullName, "Overridden, always returns SqlClient"); + + var dr = new DatabaseReader(ConnectionStrings.Northwind, "Xxxxx"); + var tables = dr.TableList(); + + Assert.IsTrue(tables.Count > 0, "We called the reader with a bogus provider type, but we got the overridden type"); + } + + [TestCleanup] + public void CleanUp() + { + //reset the overrides + FactoryTools.ProviderRepository = null; + FactoryTools.SingleProviderFactory = null; + } + + public class SuperDuperProviderFactory : DbProviderFactory + { + public static SuperDuperProviderFactory Instance = new SuperDuperProviderFactory(); } } } \ No newline at end of file