diff --git a/README.md b/README.md index 5a03d13b..c9fb0520 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,7 @@ Short overview about why anyone would use this, how it came to be (even shorter) The recommended way to use the latest version of Sigma is adding the NuGet package to your project. You can either include the core framework (command line only) [![Nuget (PreRelease)](https://img.shields.io/nuget/vpre/Sigma.Core.svg?style=flat-square)](https://www.nuget.org/packages/Sigma.Core) or the WPF visualiser (only works on Windows) which also references the core framework [![Nuget (PreRelease WPF)](https://img.shields.io/nuget/vpre/Sigma.Core.Monitors.WPF.svg?style=flat-square)](https://www.nuget.org/packages/Sigma.Core.Monitors.WPF). -In both cases, you can use any project with a main (ConsoleApplication) but you have to change the project settings to x64 since **Sigma only supports 64bit mode**. - +In both cases, you can use any project with a main (e.g. ConsoleApplication) but you have to change the project settings to x64 (since **Sigma only supports 64bit mode**) and change the target framework to **.NET 4.6** before installing the NuGet packages. ### From source diff --git a/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs index d480525e..a96a0db9 100644 --- a/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs +++ b/Sigma.Core.Monitors.WPF/Panels/Charts/AccuracyPanel.cs @@ -6,6 +6,7 @@ MIT License For full license see LICENSE in the root directory of this project. */ +using System; using System.Collections.Generic; using LiveCharts; using LiveCharts.Wpf; @@ -42,15 +43,30 @@ public AccuracyPanel(string title, ITrainer trainer, object headerContent = null /// The content for the header. If null is passed, /// the title will be used. 
/// - public AccuracyPanel(string title, ITrainer trainer, object headerContent = null, params int[] tops) : base(title, headerContent) + public AccuracyPanel(string title, ITrainer trainer, object headerContent = null, params int[] tops) : this(title, trainer, TimeStep.Every(1, TimeScale.Epoch), headerContent, tops) { + } + + /// + /// Create an AccuracyPanel with a given title. It displays given accuracies at the given time step. + /// If a title is not sufficient modify . + /// + /// The given title. + /// + /// The content for the header. If null is passed, + /// the title will be used. + /// + public AccuracyPanel(string title, ITrainer trainer, ITimeStep timeStep, object headerContent = null, params int[] tops) : base(title, headerContent) + { + if (timeStep == null) throw new ArgumentNullException(nameof(timeStep)); + // skip the first since its automatically generated for (int i = 1; i < tops.Length; i++) { AddSeries(new LineSeries()); } - trainer.AddHook(new ChartValidationAccuracyReport(this, "validation", TimeStep.Every(1, TimeScale.Epoch), tops)); + trainer.AddHook(new ChartValidationAccuracyReport(this, "validation", timeStep, tops)); AxisY.MinValue = 0; AxisY.MaxValue = 100; diff --git a/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs index 333be572..7ac2486d 100644 --- a/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs +++ b/Sigma.Core.Monitors.WPF/Panels/Charts/TrainerChartPanel.cs @@ -48,9 +48,9 @@ namespace Sigma.Core.Monitors.WPF.Panels.Charts /// The for the hook. /// The content for the header. If null is passed, /// the title will be used. 
- public TrainerChartPanel(string title, ITrainer trainer, string hookedValue, ITimeStep timestep, object headerContent = null) : base(title, headerContent) + public TrainerChartPanel(string title, ITrainer trainer, string hookedValue, ITimeStep timestep, bool averageMode = false, object headerContent = null) : base(title, headerContent) { - VisualValueReporterHook hook = new VisualValueReporterHook(this, new[] { hookedValue }, timestep); + VisualValueReporterHook hook = new VisualValueReporterHook(this, new[] { hookedValue }, timestep, averageMode); Init(trainer, hook); } @@ -65,9 +65,9 @@ public TrainerChartPanel(string title, ITrainer trainer, string hookedValue, ITi /// The for the hook. /// The content for the header. If null is passed, /// the title will be used. - public TrainerChartPanel(string title, ITrainer trainer, ITimeStep timestep, object headerContent = null, params string[] hookedValues) : base(title, headerContent) + public TrainerChartPanel(string title, ITrainer trainer, ITimeStep timestep, bool averageMode = false, object headerContent = null, params string[] hookedValues) : base(title, headerContent) { - VisualValueReporterHook hook = new VisualValueReporterHook(this, hookedValues, timestep); + VisualValueReporterHook hook = new VisualValueReporterHook(this, hookedValues, timestep, averageMode); Init(trainer, hook); } @@ -118,7 +118,7 @@ protected class VisualValueReporterHook : ValueReporterHook /// The chartpanel to which points will get added. /// The identifiers for the ; these values will get plotted. /// The for the hook (i.e. execution definition). 
- public VisualValueReporterHook(ChartPanel chartPanel, string[] valueIdentifiers, ITimeStep timestep) : base(valueIdentifiers, timestep) + public VisualValueReporterHook(ChartPanel chartPanel, string[] valueIdentifiers, ITimeStep timestep, bool averageMode = false) : base(valueIdentifiers, timestep, averageMode, false) { ParameterRegistry[ChartPanelIdentifier] = chartPanel; } @@ -127,7 +127,10 @@ public VisualValueReporterHook(ChartPanel /// Report the values for a certain epoch / iteration to a passed ChartPanel. /// /// The values by their identifier. - protected override void ReportValues(IDictionary valuesByIdentifier) + /// A boolean indicating whether or not to report the current epoch / iteration. + /// The current epoch. + /// The current iteration. + protected override void ReportValues(IDictionary valuesByIdentifier, bool reportEpochIteration, int epoch, int iteration) { ChartPanel chartPanel = (ChartPanel) ParameterRegistry[ChartPanelIdentifier]; chartPanel.Add((TData) valuesByIdentifier.Values.First()); diff --git a/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs index ca4f677d..f5790190 100644 --- a/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs +++ b/Sigma.Core.Monitors.WPF/Panels/Controls/ControlPanel.cs @@ -6,10 +6,17 @@ MIT License For full license see LICENSE in the root directory of this project. 
*/ -using System.Windows; -using System.Windows.Controls; using Sigma.Core.Monitors.WPF.View.CustomControls.Panels.Control; +using Sigma.Core.Monitors.WPF.View.Parameterisation; +using Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults; +using Sigma.Core.Monitors.WPF.View.Windows; using Sigma.Core.Training; +using Sigma.Core.Training.Hooks.Reporters; +using Sigma.Core.Utils; +using System; +using System.Collections.Generic; +using System.Windows; +using System.Windows.Controls; namespace Sigma.Core.Monitors.WPF.Panels.Controls { @@ -19,7 +26,8 @@ namespace Sigma.Core.Monitors.WPF.Panels.Controls /// public class ControlPanel : GenericPanel { - private readonly SigmaPlaybackControl _playbackControl; + private SigmaPlaybackControl _playbackControl; + private ParameterView _parameterView; private ITrainer _trainer; @@ -36,6 +44,18 @@ public ITrainer Trainer } } + /// + /// This list stores all trainers that have been initialised. + /// Required to only add one hook per trainer. + /// + private static readonly IList Trainers; + + static ControlPanel() + { + Trainers = new List(); + } + + public ControlPanel(string title, object content = null) : this(title, null, content) { } public ControlPanel(string title, ITrainer trainer, object content = null) : base(title, content) @@ -49,10 +69,61 @@ public ControlPanel(string title, ITrainer trainer, object content = null) : bas HorizontalAlignment = HorizontalAlignment.Center, Margin = new Thickness(0, 20, 0, 0) }; + } + + /// + /// This method will be called once the window is initialising (after it has been added). + /// Do not store a reference of the window unless you properly dispose it (remove reference once not required). + /// + /// The wpf window this panel will be added to. 
+ protected override void OnInitialise(WPFWindow window) + { + throw new InvalidOperationException($"{nameof(ControlPanel)} is only compatible with {nameof(SigmaWindow)}s."); + } - _playbackControl = new SigmaPlaybackControl { Trainer = Trainer }; + /// + /// This method will be called after the panel has been added (window, monitor set...) + /// + protected override void OnInitialise(SigmaWindow window) + { + if (!Trainers.Contains(Trainer)) + { + ValueSourceReporterHook valueHook = new ValueSourceReporterHook(TimeStep.Every(1, TimeScale.Epoch), "runtime_millis"); + _trainer.AddGlobalHook(valueHook); + Monitor.Sigma.SynchronisationHandler.AddSynchronisationSource(valueHook); + Trainers.Add(Trainer); + + valueHook = new ValueSourceReporterHook(TimeStep.Every(1, TimeScale.Iteration), "iteration"); + _trainer.AddLocalHook(valueHook); + Monitor.Sigma.SynchronisationHandler.AddSynchronisationSource(valueHook); + } + + //TODO: style? + _playbackControl = new SigmaPlaybackControl { Trainer = Trainer, Margin = new Thickness(0, 0, 0, 20), HorizontalAlignment = HorizontalAlignment.Center}; Content.Children.Add(_playbackControl); + + _parameterView = new ParameterView(Monitor.Sigma, window); + + //TODO: language support + + SigmaTextBlock timeBox = (SigmaTextBlock) _parameterView.Add("Running time", typeof(object), _trainer.Operator.Registry, "runtime_millis"); + timeBox.AutoPollValues(_trainer, TimeStep.Every(1, TimeScale.Epoch)); + timeBox.Postfix = " ms"; + + UserControlParameterVisualiser epochBox = (UserControlParameterVisualiser) _parameterView.Add("Current epoch", typeof(object), _trainer.Operator.Registry, "epoch"); + epochBox.AutoPollValues(_trainer, TimeStep.Every(1, TimeScale.Epoch)); + + UserControlParameterVisualiser iterationBox = (UserControlParameterVisualiser) _parameterView.Add("Current iteration", typeof(object), _trainer.Operator.Registry, "iteration"); + iterationBox.AutoPollValues(_trainer, TimeStep.Every(1, TimeScale.Iteration)); + + IRegistry registry 
= new Registry + { + { "op", Trainer.Operator.GetType().Name } + }; + _parameterView.Add("Current operator", typeof(object), registry, "op"); + + Content.Children.Add(_parameterView); } } } diff --git a/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs index c256afdc..ffa2cd40 100644 --- a/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs +++ b/Sigma.Core.Monitors.WPF/Panels/Controls/DrawPanel.cs @@ -108,7 +108,7 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) }); DataProviderUtils.ProvideExternalInputData(provider, network, block); - network.Run(Operator.Handler, false); + network.Run(Operator.Handler, trainingPass: false); DataProviderUtils.ProvideExternalOutputData(provider, network, block); } } diff --git a/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs b/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs index edf1a8ef..ea956598 100644 --- a/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs +++ b/Sigma.Core.Monitors.WPF/Panels/Parameterisation/ParameterPanel.cs @@ -9,6 +9,7 @@ For full license see LICENSE in the root directory of this project. 
using System; using Sigma.Core.Monitors.Synchronisation; using Sigma.Core.Monitors.WPF.View.Parameterisation; +using Sigma.Core.Monitors.WPF.View.Windows; using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation; using Sigma.Core.Utils; @@ -27,6 +28,7 @@ public ParameterPanel(string title, IParameterVisualiserManager visualiserManage { Content = new ParameterView(visualiserManager, synchronisationHandler); } + public ParameterPanel(string title, SigmaEnvironment environment, SigmaWindow window, object headerContent = null) : this(title, window.ParameterVisualiser, environment.SynchronisationHandler, headerContent) { } public void Add(string name, Type type, IRegistry registry, string key) { diff --git a/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs b/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs index 73028d60..5fe123e4 100644 --- a/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs +++ b/Sigma.Core.Monitors.WPF/Panels/SigmaPanel.cs @@ -9,6 +9,7 @@ For full license see LICENSE in the root directory of this project. using System.Windows; using System.Windows.Controls; using MaterialDesignThemes.Wpf; +using Sigma.Core.Monitors.WPF.View.Windows; // ReSharper disable VirtualMemberCallInConstructor @@ -38,6 +39,11 @@ public abstract class SigmaPanel : Card /// private UIElement _content; + /// + /// Currently responsible monitor - it will be automatically set when adding a new panel. (null until ) + /// + public WPFMonitor Monitor { get; set; } + /// /// Create a SigmaPanel with a given title. /// If a title is not sufficient modify . @@ -99,7 +105,7 @@ protected SigmaPanel(string title, object content = null) } else { - _content = value as UIElement ?? new Label { Content = value.ToString() }; + _content = value as UIElement ?? 
new Label {Content = value.ToString()}; ContentGrid.Children.Add(_content); } } @@ -143,10 +149,10 @@ protected virtual Grid CreateHeader(object content) { Grid header = new Grid(); - header.RowDefinitions.Add(new RowDefinition { Height = new GridLength(1, GridUnitType.Auto) }); - header.ColumnDefinitions.Add(new ColumnDefinition { Width = new GridLength(1, GridUnitType.Auto) }); + header.RowDefinitions.Add(new RowDefinition {Height = new GridLength(1, GridUnitType.Auto)}); + header.ColumnDefinitions.Add(new ColumnDefinition {Width = new GridLength(1, GridUnitType.Auto)}); - Label headerContent = new Label { Content = content }; + Label headerContent = new Label {Content = content}; header.Children.Add(headerContent); header.SetResourceReference(BackgroundProperty, "SigmaPanelHeaderBackground"); @@ -174,19 +180,54 @@ protected virtual Grid CreateContentGrid() { Grid grid = new Grid(); - grid.RowDefinitions.Add(new RowDefinition { Height = new GridLength(1, GridUnitType.Star) }); - grid.ColumnDefinitions.Add(new ColumnDefinition { Width = new GridLength(1, GridUnitType.Star) }); + grid.RowDefinitions.Add(new RowDefinition {Height = new GridLength(1, GridUnitType.Star)}); + grid.ColumnDefinitions.Add(new ColumnDefinition {Width = new GridLength(1, GridUnitType.Star)}); return grid; } + /// + /// This method invokes the initialisation of the panel (after it has been addded). + /// + public void Initialise(WPFWindow window) + { + if (window is SigmaWindow) + { + OnInitialise((SigmaWindow)window); + } + else + { + OnInitialise(window); + } + } + + /// + /// This method will be called once the window is initialising (after it has been added). + /// Do not store a reference of the window unless you properly dispose it (remove reference once not required). + /// + /// The wpf window this panel will be added to. 
+ protected virtual void OnInitialise(WPFWindow window) + { + + } + + /// + /// This method will be called once the window is initialising (after it has been added). + /// Do not store a reference of the window unless you properly dispose it (remove reference once not required). + /// + /// The wpf window this panel will be added to. + protected virtual void OnInitialise(SigmaWindow window) + { + + } + /// /// Create the default panel in which every other element is contained. /// /// The newly create . protected virtual DockPanel CreateDockPanel() { - return new DockPanel { LastChildFill = true, Margin = new Thickness(-1, 0, 0, 0) }; + return new DockPanel {LastChildFill = true, Margin = new Thickness(-1, 0, 0, 0)}; } } } \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj b/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj index 13638d38..8010611e 100644 --- a/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj +++ b/Sigma.Core.Monitors.WPF/Sigma.Core.Monitors.WPF.csproj @@ -168,6 +168,7 @@ + @@ -222,6 +223,8 @@ SigmaComboBox.xaml + + SigmaSlider.xaml @@ -233,6 +236,7 @@ + ParameterView.xaml diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaDynamicGenericBox.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaDynamicGenericBox.cs new file mode 100644 index 00000000..7e11bf59 --- /dev/null +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaDynamicGenericBox.cs @@ -0,0 +1,59 @@ +using System; +using System.ComponentModel; + +namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults +{ + /// + /// This is a generic text box that automatically uses the type of the registry in order to correctly display and parse to / from the value. + /// + /// It is necessary for XAML or other cases where generics are not possible. 
+ /// + public class SigmaDynamicGenericBox : SigmaTextBox + { + /// + /// The converter that converts the given type for the registry + /// + public TypeConverter Converter { get; protected set; } + + /// + /// The current value that is displayed + /// + public object CurrentValue { get; protected set; } + + /// + /// Force the visualiser to update its value (i.e. display the value that is stored). + /// + public override void Read() + { + object obj = SynchronisationHandler.SynchroniseGet(Registry, Key); + + if (Converter == null && obj != null) + { + Converter = TypeDescriptor.GetConverter(obj.GetType()); + } + + if (obj != null) + { + CurrentValue = obj; + Text = obj.ToString(); + } + } + + /// + /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry). + /// + public override void Write() + { + try + { + object convertedVal = Converter.ConvertFromString(Text); + Pending = true; + SynchronisationHandler.SynchroniseSet(Registry, Key, convertedVal, val => Pending = false, e => Errored = true); + } + catch (Exception) + { + Errored = true; + } + } + } +} \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaGenericBox.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaGenericBox.cs new file mode 100644 index 00000000..652d696b --- /dev/null +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaGenericBox.cs @@ -0,0 +1,71 @@ +using System; +using System.ComponentModel; +using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation; + +namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults +{ + /// + /// A textbox that can convert from and to an arbitrary object with a type converter. + /// + /// The type that is currently being represented. 
+ [GenericParameterVisualiser(typeof(double), Priority = VisualiserPriority.Lower)] + [GenericParameterVisualiser(typeof(float), Priority = VisualiserPriority.Lower)] + [GenericParameterVisualiser(typeof(short), Priority = VisualiserPriority.Lower)] + [GenericParameterVisualiser(typeof(int), Priority = VisualiserPriority.Lower)] + [GenericParameterVisualiser(typeof(long), Priority = VisualiserPriority.Lower)] + public class SigmaGenericBox : SigmaTextBox + { + /// + /// The converter that converts the given type for the registry + /// + public TypeConverter Converter { get; protected set; } + + /// + /// The current active value. + /// + protected T CurrentValue; + + /// + /// Create a generic box and initialise the converter. + /// + public SigmaGenericBox() + { + Converter = TypeDescriptor.GetConverter(typeof(T)); + } + + /// + /// Force the visualiser to update its value (i.e. display the value that is stored). + /// + public override void Read() + { + SynchronisationHandler.SynchroniseUpdate(Registry, Key, CurrentValue, val => + { + CurrentValue = val; + Text = CurrentValue.ToString(); + }); + } + + /// + /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry). 
+ /// + public override void Write() + { + try + { + T convertedValue = (T) Converter.ConvertFromString(Text); + Pending = true; + SynchronisationHandler.SynchroniseSet(Registry, Key, convertedValue, val => + { + Pending = false; + Errored = false; + }, e => Errored = true); + + } + catch (Exception) + { + Errored = true; + throw; + } + } + } +} \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml index e2dbc879..15d9a139 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaSlider.xaml @@ -26,7 +26,7 @@ IsEnabled="{Binding IsReadOnly, Converter={StaticResource InverseBooleanConverter}}" VerticalAlignment="Center"/> - diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs index fffd80fd..a650317a 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBlock.xaml.cs @@ -6,6 +6,8 @@ MIT License For full license see LICENSE in the root directory of this project. */ +using System.Windows; +using System.Windows.Controls; using Sigma.Core.Monitors.Synchronisation; using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation; using Sigma.Core.Utils; @@ -18,18 +20,22 @@ namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults [ParameterVisualiser(typeof(object), Priority = VisualiserPriority.Lower)] public partial class SigmaTextBlock { - private object _object; + /// + /// The object that is currently being displayed (without updating the displayed information) + /// + protected object _Object; /// /// The object that is being displayed (toString is called). 
/// - public object Object + public virtual object Object { - get { return _object; } + get { return _Object; } set { - _object = value; - TextBlock.Text = value?.ToString() ?? "null"; + _Object = value; + string text = value?.ToString() ?? "null"; + TextBlock.Text = Prefix + text + Postfix; } } @@ -38,6 +44,29 @@ public object Object /// public string Text => TextBlock.Text; + /// + /// This string will be added before the displayed string. + /// + public string Prefix + { + get { return (string) GetValue(PrefixProperty); } + set { SetValue(PrefixProperty, value); } + } + + public static readonly DependencyProperty PrefixProperty = + DependencyProperty.Register("Prefix", typeof(string), typeof(SigmaTextBlock), new PropertyMetadata("")); + + /// + /// This string will be added after the displayed string. + /// + public string Postfix + { + get { return (string) GetValue(PostfixProperty); } + set { SetValue(PostfixProperty, value); } + } + + public static readonly DependencyProperty PostfixProperty = + DependencyProperty.Register("Postfix", typeof(string), typeof(SigmaTextBlock), new PropertyMetadata("")); /// /// The fully resolved key to access the synchandler. @@ -86,7 +115,7 @@ public SigmaTextBlock() /// public override void Read() { - Object = SynchronisationHandler.SynchroniseGet(Registry, Key); + SynchronisationHandler.SynchroniseUpdate(Registry, Key, Object, newObj => Dispatcher.Invoke(() => Object = newObj)); } /// diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs index a1cc1bc2..9dc7c7a4 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTextBox.xaml.cs @@ -15,110 +15,6 @@ For full license see LICENSE in the root directory of this project. 
namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults { - //[ParameterVisualiser(typeof(float), Priority = VisualiserPriority.Lower)] - //internal class SigmaFloatBox : DynamicConverterBox - //{ } - - //[ParameterVisualiser(typeof(double), Priority = VisualiserPriority.Lower)] - //internal class SigmaDoubleBox : DynamicConverterBox - //{ } - - ///// - ///// Sigmas way of converting a string to given value (e.g. double) - ///// - //public class SigmaConverterBox : SigmaTextBox - //{ - // /// - // /// The converter that converts the given type for the registry - // /// - // protected readonly TypeConverter Converter; - - // public SigmaConverterBox() - // { - // Converter = TypeDescriptor.GetConverter(typeof(T)); - // } - - // /// - // /// Force the visualiser to update its value (i.e. display the value that is stored). - // /// - // public override void Read() - // { - // Text = SynchronisationHandler.SynchroniseGet(Registry, Key).ToString(); - // } - - // /// - // /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry). 
- // /// - // public override void Write() - // { - // try - // { - // T num = (T) Converter.ConvertFromString(Text); - // Pending = true; - // SynchronisationHandler.SynchroniseSet(Registry, Key, num, val => Pending = false, e => Errored = true); - // } - // catch (Exception) - // { - // Errored = true; - // } - // } - //} - [ParameterVisualiser(typeof(float), Priority = VisualiserPriority.Lower)] - [ParameterVisualiser(typeof(double), Priority = VisualiserPriority.Lower)] - [ParameterVisualiser(typeof(long), Priority = VisualiserPriority.Lower)] - [ParameterVisualiser(typeof(int), Priority = VisualiserPriority.Lower)] - [ParameterVisualiser(typeof(short), Priority = VisualiserPriority.Lower)] - public class DynamicConverterBox : SigmaTextBox - { - /// - /// The converter that converts the given type for the registry - /// - public TypeConverter Converter { get; protected set; } - - public object CurrentValue { get; protected set; } - - /// - /// Force the visualiser to update its value (i.e. display the value that is stored). - /// - public override void Read() - { - object obj = SynchronisationHandler.SynchroniseGet(Registry, Key); - - if (Converter == null && obj != null) - { - Converter = TypeDescriptor.GetConverter(obj.GetType()); - } - - if (obj != null) - { - CurrentValue = obj; - Text = obj.ToString(); - } - } - - /// - /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry). - /// - public override void Write() - { - try - { - object convertedVal = Converter.ConvertFromString(Text); - Pending = true; - SynchronisationHandler.SynchroniseSet(Registry, Key, convertedVal, val => Pending = false, e => Errored = true); - } - catch (Exception) - { - Errored = true; - -#if DEBUG - //TODO: such an ugly hack only for testing - Read(); -#endif - } - } - } - /// /// Sigmas way of displaying strings. 
/// @@ -151,7 +47,7 @@ public partial class SigmaTextBox public string Text { get { return TextBox.Text; } - set { TextBox.Text = value; } + set { Dispatcher.Invoke(() => TextBox.Text = value); } } /// @@ -181,7 +77,7 @@ public SigmaTextBox() /// public override void Read() { - Text = SynchronisationHandler.SynchroniseGet(Registry, Key); + SynchronisationHandler.SynchroniseUpdate(Registry, Key, Text, newVal => Text = newVal); } /// diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTimeBlock.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTimeBlock.cs new file mode 100644 index 00000000..b2ea855b --- /dev/null +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/Defaults/SigmaTimeBlock.cs @@ -0,0 +1,60 @@ +using System; +using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation; + +namespace Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults +{ + //TODO: editable time box (timepicker) + //[ParameterVisualiser(typeof(DateTime), Priority = VisualiserPriority.Lower)] + //public class SigmaTimeBox : SigmaTextBox + //{ + + //} + + /// + /// A TimeBlock that allows to display the current time. + /// + [ParameterVisualiser(typeof(DateTime), Priority = VisualiserPriority.Lower)] + public class SigmaTimeBlock : SigmaTextBlock + { + /// + public override object Object + { + get { return _Object; } + set + { + if (value is DateTime) + { + _Object = value; + TextBlock.Text = ((DateTime) Object).ToString(FormatString); + } + else + { + base.Object = value; + } + } + } + + /// + /// The string that is used to format the time. + /// null, if default formatting should be applied. + /// + public string FormatString { get; set; } + + /// + /// Create a label that is capable of displaying a time. + /// + public SigmaTimeBlock() + { + base.IsReadOnly = true; + } + + /// + /// Determines whether the parameter is editable or not. 
+ /// + public sealed override bool IsReadOnly + { + get { return base.IsReadOnly; } + set { throw new NotImplementedException(); } + } + } +} \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs index a04a13df..d00adbcf 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterView.xaml.cs @@ -8,10 +8,11 @@ For full license see LICENSE in the root directory of this project. using System; using System.Windows; -using System.Windows.Annotations; using System.Windows.Controls; using log4net; using Sigma.Core.Monitors.Synchronisation; +using Sigma.Core.Monitors.WPF.Annotations; +using Sigma.Core.Monitors.WPF.View.Windows; using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation; using Sigma.Core.Utils; @@ -24,11 +25,26 @@ public partial class ParameterView { private readonly ILog _log = LogManager.GetLogger(typeof(ParameterView)); + /// + /// The manager that keeps track of all visualisers. + /// protected readonly IParameterVisualiserManager Manager; + + /// + /// The handler that is used for transactions with the monitor. + /// protected readonly ISynchronisationHandler SynchronisationHandler; + /// + /// The current rowpost to which elements will be added. + /// protected int RowPos; + /// + /// Generate a new parameter view with a given manager and synchronisation handler. + /// + /// A manager of all active visualisers. + /// A handler for parameter syncing. 
public ParameterView(IParameterVisualiserManager manager, ISynchronisationHandler synchronisationHandler) { Manager = manager; @@ -36,14 +52,31 @@ public ParameterView(IParameterVisualiserManager manager, ISynchronisationHandle InitializeComponent(); } - public void Add(string name, Type type, IRegistry registry, string key) + /// + /// Generate a new parameter view with the manager and synchronisation handler assigned in the environment / window. + /// The currently active environment. + /// The currently active window (i.e. root window). + /// + public ParameterView(SigmaEnvironment environment, SigmaWindow window) : this(window.ParameterVisualiser, environment.SynchronisationHandler) + { + } + + /// + /// Display a given type stored in given registry (with the given key) next to a label with a given text. + /// + /// The text the label will contain. + /// The type that will be displayed. + /// The registry which contains the value that should be displayed. + /// The key to access the exact value required. 
+ public IParameterVisualiser Add(string name, Type type, IRegistry registry, string key) { - Add(new Label { Content = name }, type, registry, key); + return Add(new Label {Content = name}, type, registry, key); } - public void Add(UIElement name, Type visualiserType, IRegistry registry, string key) + public IParameterVisualiser Add(UIElement name, Type visualiserType, IRegistry registry, string key) { - UIElement displayer = (UIElement) Activator.CreateInstance(Manager.VisualiserType(visualiserType)); + //UIElement displayer = (UIElement) Activator.CreateInstance(Manager.VisualiserType(visualiserType)); + UIElement displayer = (UIElement) Manager.InstantiateVisualiser(visualiserType); IParameterVisualiser visualiser = displayer as IParameterVisualiser; if (visualiser == null) @@ -52,11 +85,13 @@ public void Add(UIElement name, Type visualiserType, IRegistry registry, string } Add(name, displayer, visualiser, registry, key); + + return visualiser; } public void Add(string name, object visualiserAndDisplayer, IRegistry registry, string key) { - Add(new Label { Content = name }, visualiserAndDisplayer, registry, key); + Add(new Label {Content = name}, visualiserAndDisplayer, registry, key); } public void Add(UIElement name, object visualiserAndDisplayer, IRegistry registry, string key) @@ -74,16 +109,16 @@ public void Add(UIElement name, object visualiserAndDisplayer, IRegistry registr } /// - /// + /// Add a that contains information (e.g. the name of the object), and display it with a given object (e.g. the object to interact with). /// - /// + /// The element that displays information about the element being displayed (e.g. descriptive name). /// The element that displays the object in the cell (normally the same as ). /// The object that is responsible for the link with a variable (normally the same as ). 
- /// - /// - public void Add(UIElement name, UIElement displayer, IParameterVisualiser visualiser, IRegistry registry, string key) + /// The registry which contains the value that should be displayed. May or may not be null (depending on the visualiser). + /// The key to access the exact value required. May or may not be null (depending on the visualiser). + public void Add([CanBeNull] UIElement name, [CanBeNull] UIElement displayer, [CanBeNull] IParameterVisualiser visualiser, IRegistry registry, string key) { - Content.RowDefinitions.Add(new RowDefinition { Height = GridLength.Auto }); + Content.RowDefinitions.Add(new RowDefinition {Height = GridLength.Auto}); if (visualiser != null) { @@ -93,6 +128,7 @@ public void Add(UIElement name, UIElement displayer, IParameterVisualiser visual visualiser.Read(); } + // add the name to the left if (name != null) { Grid.SetColumn(name, 0); @@ -109,5 +145,8 @@ public void Add(UIElement name, UIElement displayer, IParameterVisualiser visual RowPos++; } + + + } } diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs index 5ab287e4..447caf94 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserAttribute.cs @@ -33,10 +33,33 @@ public class ParameterVisualiserAttribute : Attribute, IParameterVisualiserInfo /// /// Define that the class visualises given type. /// - /// + /// The type that is being represented. public ParameterVisualiserAttribute(Type type) { Type = type; } + + /// + /// Determines whether the given visualiser is generic or not. + /// + public bool IsGeneric { get; protected set; } + } + + /// + /// This marks an . It contains information which type this visualiser implements + /// and reduces the amount of work required to define a new type. 
Differently from , the class implementing this + /// attribute has to be a generic class with a single attribute, which will be the given type. + /// Multiple attributes can be specified (to display strings and objects for example). + /// + public class GenericParameterVisualiserAttribute : ParameterVisualiserAttribute + { + /// + /// Define that the class visualises given type. + /// + /// The type that is being represented. + public GenericParameterVisualiserAttribute(Type type) : base(type) + { + IsGeneric = true; + } } } \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs index bcbccffc..fdfbb7e5 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/ParameterVisualiserInfo.cs @@ -18,25 +18,24 @@ namespace Sigma.Core.Monitors.WPF.View.Parameterisation /// public class ParameterVisualiserInfo : IParameterVisualiserInfo { - /// - /// The type this visualiser visualises. - /// + /// public Type Type { get; } - /// - /// The priority of the . If another priority with a lower priority has already been added, the - /// higher priority will override the settings. - /// + /// public VisualiserPriority Priority { get; set; } + /// + public bool IsGeneric { get; } + /// /// Initializes a new instance of the class. /// /// The type the is responsible for. /// The priority of the info. (higher prioriuty overrides lower ones). + /// Determines whether the given visualiser is generic or not. /// If bad enum is passed. /// If is null. 
- public ParameterVisualiserInfo([NotNull] Type type, VisualiserPriority priority = VisualiserPriority.Normal) + public ParameterVisualiserInfo([NotNull] Type type, VisualiserPriority priority = VisualiserPriority.Normal, bool isGeneric = false) { if (type == null) throw new ArgumentNullException(nameof(type)); if (!Enum.IsDefined(typeof(VisualiserPriority), priority)) @@ -46,6 +45,9 @@ public ParameterVisualiserInfo([NotNull] Type type, VisualiserPriority priority Type = type; Priority = priority; + IsGeneric = isGeneric; } + + } } \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs b/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs index fef2ae29..4b5f2627 100644 --- a/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs +++ b/Sigma.Core.Monitors.WPF/View/Parameterisation/UserControlParameterVisualiser.cs @@ -6,9 +6,12 @@ MIT License For full license see LICENSE in the root directory of this project. */ +using System; using System.Windows.Controls; using Sigma.Core.Monitors.Synchronisation; using Sigma.Core.Monitors.WPF.ViewModel.Parameterisation; +using Sigma.Core.Training; +using Sigma.Core.Training.Hooks; using Sigma.Core.Utils; namespace Sigma.Core.Monitors.WPF.View.Parameterisation @@ -62,5 +65,28 @@ public abstract class UserControlParameterVisualiser : UserControl, IParameterVi /// Force the visualiser to store its value (i.e. write the value that is displayed to the registry). /// public abstract void Write(); + + /// + /// The currently active poll hook that is responsible for updating values. + /// + protected PollParameterHook ActiveHook; + + /// + /// Enables the automatic polling of values (call Read on every TimeStep). + /// null if no automatic polling should be enabled. + /// + /// The trainer on which the poll will be performed. + /// The TimeStep on when the parameter should update. 
+ public virtual void AutoPollValues(ITrainer trainer, ITimeStep step) + { + if (ActiveHook != null) + { + trainer.Operator.DetachGlobalHook(ActiveHook); + } + + ActiveHook = new PollParameterHook(step, this); + + trainer.AddGlobalHook(ActiveHook); + } } } diff --git a/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs b/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs index 1e615dee..2428a1ba 100644 --- a/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs +++ b/Sigma.Core.Monitors.WPF/View/Windows/SigmaWindow.cs @@ -171,7 +171,7 @@ public class SigmaWindow : WPFWindow, IDisposable /// The that is responsible for creation and detection of /// visualisation elements. /// - public IParameterVisualiserManager ParameterVisualiser { get; } + public IParameterVisualiserManager ParameterVisualiser { get; } /// /// The prefix-identifier for . @@ -663,7 +663,7 @@ protected virtual void AddTabs(TabControlUI tabControl, List { foreach (string name in names) { - tabControl.AddTab(name, new TabUI(name, DefaultGridSize)); + tabControl.AddTab(name, new TabUI(Monitor, name, DefaultGridSize)); } } @@ -675,7 +675,7 @@ public virtual void AddTabs(params string[] tabs) { foreach (string tab in tabs) { - TabControl.AddTab(tab, new TabUI(tab, DefaultGridSize)); + TabControl.AddTab(tab, new TabUI(Monitor, tab, DefaultGridSize)); } } diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs index c1b230bd..1f80063f 100644 --- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs +++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiser.cs @@ -48,11 +48,13 @@ public interface IParameterVisualiser /// /// Force the visualiser to update its value (i.e. display the value that is stored). + /// This function may be called from an arbitrary thread. /// void Read(); /// /// Force the visualiser to store its value (i.e. 
write the value that is displayed to the registry). + /// This function may be called from an arbitrary thread. /// void Write(); } diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs index 9e82d5e6..5772bd64 100644 --- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs +++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserInfo.cs @@ -49,6 +49,13 @@ public interface IParameterVisualiserInfo /// Type Type { get; } + /// + /// Determines whether the given visualiser is generic or not. + /// + /// More specifically, if the given visualiser should be passed the visualisation type as generic type or not. + /// + bool IsGeneric { get; } + /// /// The priority of the . If another priority with a lower priority has already been added, the /// higher priority will override the settings. diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs index 9dd86e73..9ed90060 100644 --- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs +++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/IParameterVisualiserManager.cs @@ -8,7 +8,6 @@ For full license see LICENSE in the root directory of this project. using System; using Sigma.Core.Monitors.WPF.Annotations; -using Sigma.Core.Monitors.WPF.View.Parameterisation; using Sigma.Core.Monitors.WPF.View.Parameterisation.Defaults; namespace Sigma.Core.Monitors.WPF.ViewModel.Parameterisation @@ -50,5 +49,11 @@ public interface IParameterVisualiserManager /// The closest type for visualisation. null if not found. Type VisualiserTypeByReference([NotNull] object obj); + /// + /// Instantiate a visualiser that can represent given type. + /// + /// The type that will be visualised + /// An instance of a visualiser. 
+ IParameterVisualiser InstantiateVisualiser(Type type); } } \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs index 5cfe5117..13c55e57 100644 --- a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs +++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/ParameterVisualiserManager.cs @@ -40,11 +40,10 @@ public class ParameterVisualiserManager : IParameterVisualiserManager /// protected readonly Dictionary AttributeMapping; - /// /// The default constructor. /// - /// If true it will automatically add all classes marked with the attribute . + /// If true, it will automatically add all classes marked with the attribute or . public ParameterVisualiserManager(bool autoAssign = true) { TypeMapping = new Dictionary(); @@ -53,25 +52,29 @@ public ParameterVisualiserManager(bool autoAssign = true) if (autoAssign) { // ReSharper disable once VirtualMemberCallInConstructor - AssignMarkedClasses(); + AssignMarkedClasses(typeof(ParameterVisualiserAttribute), typeof(GenericParameterVisualiserAttribute)); } } /// - /// Assign all classes that are marked with . + /// Assign all classes that are marked with the given attributes (marker attributes). + /// These attributes have to be an . 
/// - protected virtual void AssignMarkedClasses() + protected virtual void AssignMarkedClasses(params Type[] markerTypes) { - // get all classes that have the custom attribute - IEnumerable classes = AttributeUtils.GetTypesWithAttribute(typeof(ParameterVisualiserAttribute)); - - foreach (Type @class in classes) + foreach (Type type in markerTypes) { - ParameterVisualiserAttribute[] visualisers = (ParameterVisualiserAttribute[])Attribute.GetCustomAttributes(@class, typeof(ParameterVisualiserAttribute)); + // get all classes that have the custom attribute + IEnumerable classes = AttributeUtils.GetTypesWithAttribute(type); - foreach (ParameterVisualiserAttribute visualiser in visualisers) + foreach (Type @class in classes) { - Add(@class, visualiser); + IParameterVisualiserInfo[] visualisers = (IParameterVisualiserInfo[]) Attribute.GetCustomAttributes(@class, type); + + foreach (IParameterVisualiserInfo visualiser in visualisers) + { + Add(@class, visualiser); + } } } } @@ -108,15 +111,19 @@ public virtual bool Add(Type visualiserClass, IParameterVisualiserInfo parameter // if the mapping has already been added if (TypeMapping.TryGetValue(parameterInfo.Type, out storedClass) && AttributeMapping.TryGetValue(parameterInfo.Type, out storedAttribte)) { - // if the new values have a lower priority, we return false - if (parameterInfo.Priority <= storedAttribte.Priority) + // if a different type is being represented (necessary for generics) + if (!ReferenceEquals(visualiserClass, storedClass)) { - _log.Warn($"{parameterInfo.Type} is currently visualised by {storedClass.Name}; {visualiserClass.Name} tried to be the visualiser but has a lower priority ({parameterInfo.Priority} <= {storedAttribte.Priority})."); + // if the new values have a lower priority, we return false + if (parameterInfo.Priority <= storedAttribte.Priority) + { + _log.Warn($"{parameterInfo.Type} is currently visualised by {storedClass.Name}; {visualiserClass.Name} tried to be the visualiser but has a 
lower priority ({parameterInfo.Priority} <= {storedAttribte.Priority})."); - return false; - } + return false; + } - _log.Info($"{parameterInfo.Type} was visualised by {storedClass.Name}; {visualiserClass.Name} has a higher priority and is therefore the new visualiser ({parameterInfo.Priority} > {storedAttribte.Priority})."); + _log.Debug($"{parameterInfo.Type} was visualised by {storedClass.Name}; {visualiserClass.Name} has a higher priority and is therefore the new visualiser ({parameterInfo.Priority} > {storedAttribte.Priority})."); + } } TypeMapping[parameterInfo.Type] = visualiserClass; @@ -190,5 +197,20 @@ public Type VisualiserTypeByReference(object obj) return VisualiserType(obj.GetType()); } + + + /// + public IParameterVisualiser InstantiateVisualiser(Type type) + { + Type visualiserType = VisualiserType(type); + IParameterVisualiserInfo info = AttributeMapping[type]; + + if (info.IsGeneric) + { + return (IParameterVisualiser) Activator.CreateInstance(visualiserType.MakeGenericType(type)); + } + + return (IParameterVisualiser) Activator.CreateInstance(visualiserType); + } } } \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/PollParameterHook.cs b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/PollParameterHook.cs new file mode 100644 index 00000000..d347c7a1 --- /dev/null +++ b/Sigma.Core.Monitors.WPF/ViewModel/Parameterisation/PollParameterHook.cs @@ -0,0 +1,38 @@ +using System; +using Sigma.Core.Training.Hooks; +using Sigma.Core.Utils; + +namespace Sigma.Core.Monitors.WPF.ViewModel.Parameterisation +{ + /// + /// A hook that updates a given IParameterVisualiser on a given TimeStep. + /// + [Serializable] + public class PollParameterHook : BaseHook + { + /// + /// The identifier for the currently active visualiser. + /// + protected const string VisualiserIdentifier = "visualiser"; + + /// + /// Create a hook with a certain time step + /// + /// The time step. + /// The visualisers that will be updated. 
+ public PollParameterHook(ITimeStep timestep, IParameterVisualiser visualiser) : base(timestep) + { + ParameterRegistry[VisualiserIdentifier] = visualiser; + } + + /// + /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied. + /// + /// The registry containing the required values for this hook's execution. + /// A helper resolver for complex registry entries (automatically cached). + public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) + { + ((IParameterVisualiser) ParameterRegistry[VisualiserIdentifier]).Read(); + } + } +} \ No newline at end of file diff --git a/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs b/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs index 9fe107bc..b737ce86 100644 --- a/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs +++ b/Sigma.Core.Monitors.WPF/ViewModel/Tabs/TabUI.cs @@ -17,6 +17,7 @@ For full license see LICENSE in the root directory of this project. using Sigma.Core.Monitors.WPF.Model.UI.Windows; using Sigma.Core.Monitors.WPF.Panels; using Sigma.Core.Monitors.WPF.View; +using Sigma.Core.Monitors.WPF.View.Windows; using Sigma.Core.Monitors.WPF.ViewModel.CustomControls; using WPFGrid = System.Windows.Controls.Grid; @@ -39,17 +40,24 @@ public class TabUI : UIWrapper /// private GridSize _gridSize; + /// + /// The monitor this tab is assigned to. + /// + protected WPFMonitor Monitor; + /// /// Create a new - this basically is a /// with additional control. /// + /// The monitor this tab is assigned to. /// The header of the tab (name in the ) /// - /// The . Use - /// . + /// The . Use + /// . 
/// - public TabUI(string header, GridSize gridsize) + public TabUI(WPFMonitor monitor, string header, GridSize gridsize) { + Monitor = monitor; Content.Header = header; GridSize = gridsize; } @@ -205,8 +213,10 @@ protected virtual void ApplyLegend(SigmaPanel panel, StatusBarLegendInfo legend) public void AddCumulativePanel(SigmaPanel panel, int rowSpan = 1, int columnSpan = 1, StatusBarLegendInfo legend = null) { + panel.Monitor = Monitor; AddCumulativeElement(panel, rowSpan, columnSpan); ApplyLegend(panel, legend); + panel.Initialise(Monitor.Window); } /// @@ -266,8 +276,10 @@ public void AddCumulativeElement(UIElement element, int rowSpan = 1, int columnS public void AddPanel(SigmaPanel panel, int row, int column, int rowSpan = 1, int columnSpan = 1, StatusBarLegendInfo legend = null) { + panel.Monitor = Monitor; AddElement(panel, row, column, rowSpan, columnSpan); ApplyLegend(panel, legend); + panel.Initialise(Monitor.Window); } /// diff --git a/Sigma.Core/Architecture/INetwork.cs b/Sigma.Core/Architecture/INetwork.cs index 5666ab18..b99ace47 100644 --- a/Sigma.Core/Architecture/INetwork.cs +++ b/Sigma.Core/Architecture/INetwork.cs @@ -34,6 +34,16 @@ public interface INetwork : IDeepCopyable /// IRegistry Registry { get; } + /// + /// The computation handler associated with this network, which is used for initialisation and copy operations. + /// + IComputationHandler AssociatedHandler { get; set; } + + /// + /// Indicate if this network was already initialised. + /// + bool Initialised { get; } + /// /// Validate this network (e.g. ensure all connections are correctly assigned and compatible). /// diff --git a/Sigma.Core/Architecture/Network.cs b/Sigma.Core/Architecture/Network.cs index 1e2b233a..0c04e528 100644 --- a/Sigma.Core/Architecture/Network.cs +++ b/Sigma.Core/Architecture/Network.cs @@ -19,266 +19,284 @@ For full license see LICENSE in the root directory of this project. 
namespace Sigma.Core.Architecture { - /// - /// A default implementation of the interface. - /// Represents a neural network consisting of interconnected neural layers and a network architecture. - /// - [Serializable] - public class Network : INetwork - { - /// - public INetworkArchitecture Architecture { get; set; } - - /// - public string Name { get; } - - /// - public IRegistry Registry { get; } - - [NonSerialized] - private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); - private readonly List _orderedLayerBuffers; - private readonly List _externalInputsLayerBuffers; - private readonly List _externalOutputsLayerBuffers; - private List _orderedLayers; - private IComputationHandler _initialisationHandler; - private bool _initialised; - - /// - /// Create a network with a certain unique name. - /// - /// The name. - public Network(string name = "unnamed") - { - if (name == null) throw new ArgumentNullException(nameof(name)); - - Name = name; - Registry = new Registry(tags: "network"); - _orderedLayerBuffers = new List(); - _externalInputsLayerBuffers = new List(); - _externalOutputsLayerBuffers = new List(); - } - - /// - public virtual object DeepCopy() - { - Network copy = new Network(Name); - copy.Architecture = (INetworkArchitecture) Architecture.DeepCopy(); - - if (_initialised) - { - copy.Initialise(_initialisationHandler); - - for (int i = 0; i < _orderedLayerBuffers.Count; i++) - { - InternalLayerBuffer originalBuffer = _orderedLayerBuffers[i]; - InternalLayerBuffer copyBuffer = copy._orderedLayerBuffers[i]; - - foreach (string parameterIdentifier in originalBuffer.Parameters.Keys.ToArray()) - { - object value = originalBuffer.Parameters[parameterIdentifier]; - IDeepCopyable deepCopyableValue = value as IDeepCopyable; - object copiedValue; - - // copy and copy efficiently by any means possible - if (deepCopyableValue == null) - { - ICloneable cloneableValue = value as ICloneable; - copiedValue 
= cloneableValue?.Clone() ?? value; - } - else - { - INDArray asNDArray = value as INDArray; - - if (asNDArray != null) - { - _initialisationHandler.Fill(asNDArray, copyBuffer.Parameters.Get(parameterIdentifier)); - } - - copiedValue = deepCopyableValue.DeepCopy(); - } - - copyBuffer.Parameters[parameterIdentifier] = copiedValue; - } - } - } - - return copy; - } - - /// - public void Validate() - { - if (Architecture == null) - { - throw new InvalidOperationException("Cannot validate network before assigning a network architecture."); - } - - Architecture.Validate(); - } - - /// - public void Initialise(IComputationHandler handler) - { - if (handler == null) throw new ArgumentNullException(nameof(handler)); - - if (Architecture == null) - { - throw new InvalidOperationException("Cannot initialise network before assigning a network architecture."); - } + /// + /// A default implementation of the interface. + /// Represents a neural network consisting of interconnected neural layers and a network architecture. + /// + [Serializable] + public class Network : INetwork + { + /// + public INetworkArchitecture Architecture { get; set; } + + /// + public string Name { get; } + + /// + public IRegistry Registry { get; } + + /// + /// The computation handler associated with this network, which is used for initialisation and copy operations. 
+ /// Note: Set this + /// + public IComputationHandler AssociatedHandler + { + get { return _associatedHandler; } + set { _associatedHandler = value; } + } + + /// + public bool Initialised { get { return _initialised; } } + + [NonSerialized] + private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); + private readonly List _orderedLayerBuffers; + private readonly List _externalInputsLayerBuffers; + private readonly List _externalOutputsLayerBuffers; + private List _orderedLayers; + + [NonSerialized] + private IComputationHandler _associatedHandler; + private bool _initialised; + + /// + /// Create a network with a certain unique name. + /// + /// The name. + public Network(string name = "unnamed") + { + if (name == null) throw new ArgumentNullException(nameof(name)); + + Name = name; + Registry = new Registry(tags: "network"); + _orderedLayerBuffers = new List(); + _externalInputsLayerBuffers = new List(); + _externalOutputsLayerBuffers = new List(); + } + + /// + public virtual object DeepCopy() + { + Network copy = new Network(Name); + copy.Architecture = (INetworkArchitecture) Architecture.DeepCopy(); + + if (_initialised) + { + copy.Initialise(_associatedHandler); + + for (int i = 0; i < _orderedLayerBuffers.Count; i++) + { + InternalLayerBuffer originalBuffer = _orderedLayerBuffers[i]; + InternalLayerBuffer copyBuffer = copy._orderedLayerBuffers[i]; + + foreach (string parameterIdentifier in originalBuffer.Parameters.Keys.ToArray()) + { + object value = originalBuffer.Parameters[parameterIdentifier]; + IDeepCopyable deepCopyableValue = value as IDeepCopyable; + object copiedValue; + + // copy and copy efficiently by any means possible + if (deepCopyableValue == null) + { + ICloneable cloneableValue = value as ICloneable; + copiedValue = cloneableValue?.Clone() ?? 
value; + } + else + { + INDArray asNDArray = value as INDArray; + + if (asNDArray != null) + { + _associatedHandler.Fill(asNDArray, copyBuffer.Parameters.Get(parameterIdentifier)); + copiedValue = copyBuffer.Parameters.Get(parameterIdentifier); + } + else + { + copiedValue = deepCopyableValue.DeepCopy(); + } + } + + copyBuffer.Parameters[parameterIdentifier] = copiedValue; + } + } + } + + return copy; + } + + /// + public void Validate() + { + if (Architecture == null) + { + throw new InvalidOperationException("Cannot validate network before assigning a network architecture."); + } + + Architecture.Validate(); + } + + /// + public void Initialise(IComputationHandler handler) + { + if (handler == null) throw new ArgumentNullException(nameof(handler)); + + if (Architecture == null) + { + throw new InvalidOperationException("Cannot initialise network before assigning a network architecture."); + } + + _logger.Debug($"Initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers..."); + + _associatedHandler = handler; + + ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare); + + Architecture.ResolveAllNames(); + + _orderedLayerBuffers.Clear(); + _externalInputsLayerBuffers.Clear(); + _externalOutputsLayerBuffers.Clear(); + + Dictionary, IRegistry> mappedRegistriesByInOutputs = new Dictionary, IRegistry>(); + + foreach (LayerConstruct layerConstruct in Architecture.YieldLayerConstructsOrdered()) + { + ILayer layer = layerConstruct.InstantiateLayer(handler); + + Dictionary inputs = new Dictionary(); + + foreach (string externalInputAlias in layerConstruct.ExternalInputs) + { + inputs[externalInputAlias] = new Registry(tags: "external_input"); + } - _logger.Debug($"Initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers..."); + foreach (string inputAlias in layerConstruct.Inputs.Keys) + { + inputs[inputAlias] = mappedRegistriesByInOutputs[new 
Tuple(layerConstruct.Inputs[inputAlias], layerConstruct)]; + } - _initialisationHandler = handler; - - ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare); - - Architecture.ResolveAllNames(); - - _orderedLayerBuffers.Clear(); - _externalInputsLayerBuffers.Clear(); - _externalOutputsLayerBuffers.Clear(); - - Dictionary, IRegistry> mappedRegistriesByInOutputs = new Dictionary, IRegistry>(); - - foreach (LayerConstruct layerConstruct in Architecture.YieldLayerConstructsOrdered()) - { - ILayer layer = layerConstruct.InstantiateLayer(handler); - - Dictionary inputs = new Dictionary(); + Dictionary outputs = new Dictionary(); - foreach (string externalInputAlias in layerConstruct.ExternalInputs) - { - inputs[externalInputAlias] = new Registry(tags: "external_input"); - } + foreach (string externalOutputAlias in layerConstruct.ExternalOutputs) + { + outputs[externalOutputAlias] = new Registry(tags: "external_output"); + } - foreach (string inputAlias in layerConstruct.Inputs.Keys) - { - inputs[inputAlias] = mappedRegistriesByInOutputs[new Tuple(layerConstruct.Inputs[inputAlias], layerConstruct)]; - } + foreach (string outputAlias in layerConstruct.Outputs.Keys) + { + LayerConstruct outputConstruct = layerConstruct.Outputs[outputAlias]; + + Tuple inOuTuple = new Tuple(layerConstruct, outputConstruct); - Dictionary outputs = new Dictionary(); + Registry outRegistry = new Registry(tags: "internal"); + + mappedRegistriesByInOutputs.Add(inOuTuple, outRegistry); + + outputs[outputAlias] = outRegistry; + } - foreach (string externalOutputAlias in layerConstruct.ExternalOutputs) - { - outputs[externalOutputAlias] = new Registry(tags: "external_output"); - } + InternalLayerBuffer layerBuffer = new InternalLayerBuffer(layer, layerConstruct.Parameters, inputs, outputs, + layerConstruct.ExternalInputs, layerConstruct.ExternalOutputs); + + _orderedLayerBuffers.Add(layerBuffer); + + if (layerConstruct.ExternalInputs.Length > 0) + { + 
_externalInputsLayerBuffers.Add(layerBuffer); + } - foreach (string outputAlias in layerConstruct.Outputs.Keys) - { - LayerConstruct outputConstruct = layerConstruct.Outputs[outputAlias]; + if (layerConstruct.ExternalOutputs.Length > 0) + { + _externalOutputsLayerBuffers.Add(layerBuffer); + } + } - Tuple inOuTuple = new Tuple(layerConstruct, outputConstruct); + _orderedLayers = _orderedLayerBuffers.ConvertAll(buffer => buffer.Layer); - Registry outRegistry = new Registry(tags: "internal"); + UpdateRegistry(); - mappedRegistriesByInOutputs.Add(inOuTuple, outRegistry); + SigmaEnvironment.TaskManager.EndTask(prepareTask); - outputs[outputAlias] = outRegistry; - } + _initialised = true; - InternalLayerBuffer layerBuffer = new InternalLayerBuffer(layer, layerConstruct.Parameters, inputs, outputs, - layerConstruct.ExternalInputs, layerConstruct.ExternalOutputs); + _logger.Debug($"Done initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers."); + } - _orderedLayerBuffers.Add(layerBuffer); + protected virtual void UpdateRegistry() + { + Registry.Clear(); - if (layerConstruct.ExternalInputs.Length > 0) - { - _externalInputsLayerBuffers.Add(layerBuffer); - } + Registry["initialised"] = _initialised; + Registry["self"] = this; + Registry["name"] = Name; + Registry["architecture"] = Architecture?.Registry; - if (layerConstruct.ExternalOutputs.Length > 0) - { - _externalOutputsLayerBuffers.Add(layerBuffer); - } - } + Registry layersRegistry = new Registry(Registry); + Registry["layers"] = layersRegistry; - _orderedLayers = _orderedLayerBuffers.ConvertAll(buffer => buffer.Layer); + foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers) + { + layersRegistry[layerBuffer.Layer.Name] = layerBuffer.Layer.Parameters; + } + } - UpdateRegistry(); + /// + public void Run(IComputationHandler handler, bool trainingPass) + { + if (handler == null) throw new ArgumentNullException(nameof(handler)); - 
SigmaEnvironment.TaskManager.EndTask(prepareTask); + foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers) + { + layerBuffer.Layer.Run(layerBuffer, handler, trainingPass); + } + } - _initialised = true; + /// + public void Reset() + { + _logger.Debug($"Resetting network \"{Name}\" to un-initialised state..."); - _logger.Debug($"Done initialising network \"{Name}\" for handler {handler} containing {Architecture.LayerCount} layers."); - } + _orderedLayerBuffers.Clear(); + _orderedLayers.Clear(); + _externalInputsLayerBuffers.Clear(); + _externalOutputsLayerBuffers.Clear(); - protected virtual void UpdateRegistry() - { - Registry.Clear(); + _initialised = false; + _associatedHandler = null; - Registry["initialised"] = _initialised; - Registry["self"] = this; - Registry["name"] = Name; - Registry["architecture"] = Architecture?.Registry; + UpdateRegistry(); - Registry layersRegistry = new Registry(Registry); - Registry["layers"] = layersRegistry; + _logger.Debug($"Done resetting network \"{Name}\". 
All layer buffer information was discarded."); + } - foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers) - { - layersRegistry[layerBuffer.Layer.Name] = layerBuffer.Layer.Parameters; - } - } + /// + public IEnumerable YieldLayersOrdered() + { + return _orderedLayers; + } - /// - public void Run(IComputationHandler handler, bool trainingPass) - { - if (handler == null) throw new ArgumentNullException(nameof(handler)); + /// + public IEnumerable YieldLayerBuffersOrdered() + { + return _orderedLayerBuffers; + } - foreach (InternalLayerBuffer layerBuffer in _orderedLayerBuffers) - { - layerBuffer.Layer.Run(layerBuffer, handler, trainingPass); - } - } + /// + public IEnumerable YieldExternalInputsLayerBuffers() + { + return _externalInputsLayerBuffers; + } - /// - public void Reset() - { - _logger.Debug($"Resetting network \"{Name}\" to un-initialised state..."); - - _orderedLayerBuffers.Clear(); - _orderedLayers.Clear(); - _externalInputsLayerBuffers.Clear(); - _externalOutputsLayerBuffers.Clear(); - - _initialised = false; - _initialisationHandler = null; - - UpdateRegistry(); - - _logger.Debug($"Done resetting network \"{Name}\". 
All layer buffer information was discarded."); - } - - /// - public IEnumerable YieldLayersOrdered() - { - return _orderedLayers; - } - - /// - public IEnumerable YieldLayerBuffersOrdered() - { - return _orderedLayerBuffers; - } - - /// - public IEnumerable YieldExternalInputsLayerBuffers() - { - return _externalInputsLayerBuffers; - } - - /// - public IEnumerable YieldExternalOutputsLayerBuffers() - { - return _externalOutputsLayerBuffers; - } - - /// - public INetworkSelector Select() - { - return new DefaultNetworkSelector(this); - } - } + /// + public IEnumerable YieldExternalOutputsLayerBuffers() + { + return _externalOutputsLayerBuffers; + } + + /// + public INetworkSelector Select() + { + return new DefaultNetworkSelector(this); + } + } } diff --git a/Sigma.Core/Data/Datasets/Dataset.cs b/Sigma.Core/Data/Datasets/Dataset.cs deleted file mode 100644 index d84396f4..00000000 --- a/Sigma.Core/Data/Datasets/Dataset.cs +++ /dev/null @@ -1,1048 +0,0 @@ -/* -MIT License - -Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer - -For full license see LICENSE in the root directory of this project. -*/ - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using log4net; -using Sigma.Core.Data.Extractors; -using Sigma.Core.Handlers; -using Sigma.Core.MathAbstract; -using Sigma.Core.Persistence; -using Sigma.Core.Utils; - -namespace Sigma.Core.Data.Datasets -{ - /// - /// A default implementation of the IDataset interface. - /// Provides caching of entire blocks and reader data, partial extraction, unordered extraction, automatic block sizing, smart block loading. - /// - [Serializable] - public class Dataset : IDataset, ISerialisationNotifier - { - [NonSerialized] - private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); - - /// - /// Automatically size blocks according to estimated data metrics (e.g. 
physical memory available, record size). - /// - public const int BlockSizeAuto = -1; - - /// - /// Assign all available data to the first block (one block fits it all - literally). - /// - public const int BlockSizeAll = -2; - - /// - public string Name { get; } - - /// - /// Indicate if this dataset is an online dataset (meaning new data might be added during runtime). - /// By default, this is assumed to be false, indicating a static dataset. - /// Note: Data iterators and may perform certain optimisations for static datasets, so set this to false if possible. - /// - public bool Online { get; set; } = false; - - /// - public int MaxConcurrentActiveBlocks { get; } = 24; //24 seems like a good number, right? - - /// - public long MaxTotalActiveBlockSizeBytes { get; } = SystemInformationUtils.GetAvailablePhysicalMemoryBytes() / 2; //default to half the available physical memory - - /// - public IReadOnlyCollection ActiveBlockIndices => _activeBlocks.Keys.ToList(); - - /// - public int ActiveBlockRegionCount => _activeBlocks.Count; - - /// - public int ActiveIndividualBlockCount { get { return _activeBlocks.Values.Sum(set => set.Count); } } - - /// - public int TargetBlockSizeRecords { get; private set; } - - /// - public string[] SectionNames { get; private set; } - - /// - public long TotalActiveBlockSizeBytes { get; private set; } - - public long TotalActiveRecords { get; private set; } - - /// - public int MaxBlocksInCache { get; set; } = int.MaxValue; - - /// - public long MaxBytesInCache { get; set; } = long.MaxValue; - - /// - /// Indicate whether this dataset should cache the raw reader data. - /// If disabled, only extracted data will be cached and once processed, it might be impossible to retrieve preceding record blocks (reader streams are assumed to be non-seekable). 
- /// - public bool AllowRawReadDataCaching { get; set; } = true; - - private readonly Dictionary> _activeBlocks; - private readonly Dictionary> _cachedBlocks; - private readonly ICacheProvider _cacheProvider; - - private int _lastReadRawDataBlockIndex = -1; - private long _totalCachedBlockSizeBytes; - private int _lastAvailableBlockIndex = int.MaxValue; - private readonly ISet _recordExtractors; - - private readonly bool _autoSetBlockSize; - private bool _autoSetExternalChangeBlockSize; - - // TODO fix available blocks semaphore logic - // the waitones/releases are inconsistent, because blocks aren't always actually allocated, such as null returns are not considered - [NonSerialized] - private Semaphore _availableBlocksSemaphore; - private int _availableBlocksSemaphoreState; - - /// - /// Create a dataset with a certain unique name and the record extractors to use. - /// - /// The unique dataset name. - /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. - public Dataset(string name, params IRecordExtractor[] recordExtractors) : this(name, BlockSizeAuto, recordExtractors) - { - } - - /// - /// Create a dataset with a certain unique name, target block size in records and the record extractors to use. - /// - /// The unique dataset name. - /// The target block size for records. May also be or . - /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. - public Dataset(string name, int blockSizeRecords, params IRecordExtractor[] recordExtractors) - : this(name, blockSizeRecords, new DiskCacheProvider(SigmaEnvironment.Globals.Get("cache_path") + name), true, recordExtractors) - { - } - - /// - /// Create a dataset with a certain unique name, target block size in records, specific cache provider and the record extractors to use. - /// - /// The unique dataset name. - /// The target block size for records. May also be or . 
- /// The cache provider to use for caching record blocks and raw reader data. - /// Indicate whether the cache provider should be flushed (cleared) before use. Only disable if block size and extractors used do not change (otherwise undefined behaviour). - /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. - public Dataset(string name, int blockSizeRecords, ICacheProvider cacheProvider, bool flushCache = true, params IRecordExtractor[] recordExtractors) - { - if (name == null) - { - throw new ArgumentNullException(nameof(name)); - } - - if (recordExtractors == null) - { - throw new ArgumentNullException(nameof(recordExtractors)); - } - - if (recordExtractors.Length == 0) - { - throw new ArgumentException("Datasets require at least one record extractor, but none were given."); - } - - if (cacheProvider == null) - { - throw new ArgumentNullException(nameof(cacheProvider)); - } - - switch (blockSizeRecords) { - case BlockSizeAll: - //just set to maximum amount of records, extracting returns the maximum available anyway and we can't know the actual availability yet - TargetBlockSizeRecords = int.MaxValue; - break; - case BlockSizeAuto: - //somewhat temporary guesstimate, should probably expose the individual parameters - const long estimatedRecordSizeBytes = 1024; - const double memoryToConsume = 0.2f; - const long optimalNumberBlocks = 8; - const int maxBlockSizeRecords = 4096; - long availableSystemMemory = SystemInformationUtils.GetAvailablePhysicalMemoryBytes(); - - TargetBlockSizeRecords = Math.Min(maxBlockSizeRecords, (int) (availableSystemMemory * memoryToConsume / estimatedRecordSizeBytes / optimalNumberBlocks)); - - _autoSetBlockSize = true; - break; - default: - if (blockSizeRecords == 0 || blockSizeRecords < -2) - { - throw new ArgumentException($"Block size in records must be either BLOCK_SIZE_ALL, BLOCK_SIZE_AUTO or > 0, but given block size was {blockSizeRecords}."); - } - else - { - 
TargetBlockSizeRecords = blockSizeRecords; - } - break; - } - - Name = name; - AnalyseExtractors(recordExtractors); - - _cacheProvider = cacheProvider; - _recordExtractors = new HashSet(recordExtractors); - - _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks, MaxConcurrentActiveBlocks); - _availableBlocksSemaphoreState = MaxConcurrentActiveBlocks; - - _activeBlocks = new Dictionary>(); - _cachedBlocks = new Dictionary>(); - - if (flushCache) - { - _logger.Debug($"Flushing all caches for dataset \"{Name}\" as flushCache flag was set..."); - - InvalidateAndClearCaches(); - - _logger.Debug($"Done flushing all caches for dataset \"{Name}.\""); - } - } - - /// - /// Called before this object is serialised. - /// - public void OnSerialising() - { - } - - /// - /// Called after this object was serialised. - /// - public void OnSerialised() - { - } - - /// - /// Called after this object was de-serialised. - /// - public void OnDeserialised() - { - InvalidateAndClearCaches(); - _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks - ActiveIndividualBlockCount, MaxConcurrentActiveBlocks); - } - - public IDataset[] SplitBlockwise(params int[] parts) - { - return SplitBlockwise(this, parts); - } - - public IDataset[] SplitRecordwise(params double[] parts) - { - return SplitRecordwise(this, parts); - } - - public bool TrySetBlockSize(int blockSizeRecords) - { - if (blockSizeRecords == TargetBlockSizeRecords) - { - //nothing to do here - return true; - } - - if (!_autoSetBlockSize) - { - _logger.Debug($"Cannot change block size as block size was not set automatically (attempted to change block size to {blockSizeRecords}."); - - return false; - } - - if (_activeBlocks.Count > 0 || _cachedBlocks.Count > 0) - { - _logger.Debug($"Cannot change block size as {_activeBlocks.Count + _cachedBlocks.Count} blocks were already fetched and are active or cached."); - - return false; - } - - if (_autoSetExternalChangeBlockSize && blockSizeRecords != 
TargetBlockSizeRecords) - { - _logger.Debug($"Cannot change block size to {blockSizeRecords}, block size is incompatible with another external block size change request (other request: {TargetBlockSizeRecords})"); - - return false; - } - - _autoSetExternalChangeBlockSize = true; - TargetBlockSizeRecords = blockSizeRecords; - - return true; - } - - private void AnalyseExtractors(IEnumerable extractors) - { - ISet sectionNames = new HashSet(); - - int index = 0; - foreach (IRecordExtractor extractor in extractors) - { - if (extractor == null) - { - throw new ArgumentNullException($"Extractor at index {index} was null."); - } - - if (extractor.SectionNames == null) - { - throw new ArgumentNullException($"Section names field in extractor {extractor} was null (field has to be set by extractor)."); - } - - string[] extractorSectionNames = extractor.SectionNames; - - foreach (string sectionName in extractorSectionNames) - { - if (sectionNames.Contains(sectionName)) - { - throw new ArgumentException($"Section name collision: duplicate section name {sectionName} detected for extractor {extractor}."); - } - else - { - sectionNames.Add(sectionName); - } - } - - index++; - } - - SectionNames = sectionNames.ToArray(); - } - - public int GetNumberOfLoadedInactiveCachedBlocks() - { - return _cachedBlocks.Values.SelectMany(blockSet => blockSet).Count(block => block.Loaded); - } - - public bool CanFetchBlocksAfter(int blockIndex) - { - return blockIndex <= _lastAvailableBlockIndex; - } - - public async Task> FetchBlockAsync(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true) - { - //TODO check if block even could be fetched to not waste thread resources if shouldWaitUntilAvailable is false anyway - - return await Task.Run(() => FetchBlock(blockIndex, handler, shouldWaitUntilAvailable)); - } - - public IDictionary FetchBlock(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true) - { - Dictionary block = 
FetchBlockConstrained(blockIndex, handler); - - //block could be fetched directly without violating any constraints, return successfully - if (block != null) - { - if (block.Count == 0) - { - throw new InvalidOperationException("Fetched block did not contain any named elements (was empty; is the extractor output correct?)."); - } - - RegisterActiveBlock(block, blockIndex, handler); - - return block; - } - else - { - if (blockIndex >= _lastAvailableBlockIndex) - { - return null; - } - - if (shouldWaitUntilAvailable) - { - _logger.Debug($"Could not directly load block with index {blockIndex} for handler {handler} and shouldWaitUntilAvailable flag is set to true, waiting for available space..."); - - return FetchBlockWhenAvailable(blockIndex, handler); - } - else - { - return null; - } - } - } - - private void RegisterActiveBlock(Dictionary block, int blockIndex, IComputationHandler handler) - { - INDArray firstNamedBlock = block[block.First().Key]; - - if (IsBlockActive(blockIndex, handler)) - { - //block already registered as active, nothing to do here - return; - } - - RecordBlock recordBlock = new RecordBlock(block, blockIndex, firstNamedBlock.Shape[0], - handler.GetSizeBytes(block.Values.ToArray()), handler) - { Loaded = true, Active = true }; - - lock (this) - { - TotalActiveBlockSizeBytes += recordBlock.EstimatedSizeBytes; - TotalActiveRecords += recordBlock.NumberRecords; - - if (!_activeBlocks.ContainsKey(blockIndex)) - { - _activeBlocks.Add(blockIndex, new HashSet()); - } - - _activeBlocks[blockIndex].Add(recordBlock); - } - } - - private void DeregisterActiveBlock(RecordBlock recordBlock) - { - if (!IsBlockActive(recordBlock.BlockIndex, recordBlock.Handler)) - { - //block that should be de-registered is not even registered - return; - } - - lock (this) - { - TotalActiveBlockSizeBytes -= recordBlock.EstimatedSizeBytes; - TotalActiveRecords -= recordBlock.NumberRecords; - - _activeBlocks[recordBlock.BlockIndex].Remove(recordBlock); - - if 
(_activeBlocks[recordBlock.BlockIndex].Count == 0) - { - _activeBlocks.Remove(recordBlock.BlockIndex); - } - } - } - - private void RegisterCachedBlock(Dictionary block, int blockIndex, IComputationHandler handler, bool keepReference) - { - if (IsBlockCached(blockIndex, handler)) - { - //block's already cached, nothing to do here - return; - } - - if (!_cachedBlocks.ContainsKey(blockIndex)) - { - _cachedBlocks.Add(blockIndex, new HashSet()); - } - - WeakRecordBlock recordBlock = new WeakRecordBlock(keepReference ? block : null, blockIndex, block.First().Value.Shape[0], handler.GetSizeBytes(block.Values.ToArray()), handler); - - recordBlock.Loaded = false; - - _cachedBlocks[blockIndex].Add(recordBlock); - } - - /// - /// Invalidate and clear all caches associated with this dataset. - /// WARNING: Removing cache entries may cause certain datasets to load much more slowly or even incorrectly. - /// Legitimate use cases include removing cache entries for old datasets or changing extractors. 
- /// - public void InvalidateAndClearCaches() - { - _logger.Debug("Invalidating and clearing all caches..."); - - _cacheProvider.RemoveAll(); - _cachedBlocks.Clear(); - _totalCachedBlockSizeBytes = 0L; - - _logger.Debug("Done invalidating and clearing all caches."); - } - - private Dictionary FetchBlockWhenAvailable(int blockIndex, IComputationHandler handler) - { - while (true) - { - _logger.Debug($"Attempting to extract block region for request for block index {blockIndex} for handler {handler}, checking if it fits all constraints..."); - - Dictionary block = FetchBlockConstrained(blockIndex, handler); - - //if block != null we could fetch the block successfully without violating any constraints - if (block != null) - { - RegisterActiveBlock(block, blockIndex, handler); - - return block; - } - else - { - //we cannot retrieve any more blocks and shouldn't keep trying - if (blockIndex >= _lastAvailableBlockIndex) - { - return null; - } - - _logger.Debug($"Request for block with index {blockIndex} for handler {handler} was returned to the queue, seems to be violating constraints..."); - } - } - } - - private Dictionary FetchBlockConstrained(int blockIndex, IComputationHandler handler) - { - if (ActiveIndividualBlockCount >= MaxConcurrentActiveBlocks) - { - _logger.Debug($"Unable to fetch block due to MaxConcurrentActiveBlocks constraint of {MaxConcurrentActiveBlocks}."); - - return null; - } - - Dictionary block = LoadAndExtractBlockWhenAvailable(blockIndex, handler); - - //there was nothing to load and extract, most likely end of stream - if (block == null) - { - return null; - } - - long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray()); - - if (TotalActiveBlockSizeBytes + blockSizeBytes > MaxTotalActiveBlockSizeBytes) - { - _logger.Debug($"Unable to keep requested block {blockIndex} for handler {handler} in memory due to MaxTotalActiveBlockSizeBytes constraint of {MaxTotalActiveBlockSizeBytes} bytes (block of size {blockSizeBytes} would exceed 
constraint by {TotalActiveBlockSizeBytes + blockSizeBytes - MaxTotalActiveBlockSizeBytes} bytes.)."); - - CacheBlockConstrained(block, blockIndex, handler); - - return null; - } - - return block; - } - - private Dictionary LoadAndExtractBlockWhenAvailable(int blockIndex, IComputationHandler handler) - { - //this method takes care of - // - checking whether the index is already loaded and active and then converts it - // - or checking whether the index is already cached in the right format and loads - // - or if none of that, loads and extracts from the original extractors - - //check whether a block with the same index and format is already active - if (_activeBlocks.ContainsKey(blockIndex)) - { - Dictionary block = GetBestMatchedBlockWhenAvailable(_activeBlocks[blockIndex], handler); - - if (block != null) - { - return block; - } - } - - //check whether a block with the same index and format is already loaded and cached but not active - if (_cachedBlocks.ContainsKey(blockIndex)) - { - Dictionary block = GetBestMatchedBlockWhenAvailable(_cachedBlocks[blockIndex], handler); - - if (block != null) - { - return block; - } - } - - lock (_cacheProvider) - { - string blockIdentifierInCache = $"extracted.{blockIndex}.{handler.DataType.Identifier}"; - - //check whether a block of the same index and format is cached in the cache provider - if (_cacheProvider.IsCached(blockIdentifierInCache)) - { - Dictionary block = _cacheProvider.Load>(blockIdentifierInCache); - - //if its != null we could read it correctly in the right format - if (block != null) - { - //register this cache entry as a properly loaded block in case the cache wasn't flushed and the cache map is outdated - RegisterCachedBlock(block, blockIndex, handler, keepReference: false); - - return block; - } - } - } - - //_availableBlocksSemaphore.WaitOne(); - //_availableBlocksSemaphoreState--; - - return LoadAndExtractRaw(blockIndex, handler); - } - - private Dictionary GetBestMatchedBlockWhenAvailable(IEnumerable 
blocks, IComputationHandler handler) - { - RecordBlockBase bestMatchedBlock = null; - - foreach (RecordBlockBase otherBlock in blocks) - { - if (otherBlock.Loaded && handler.CanConvert(otherBlock.FirstNamedBlock, otherBlock.Handler)) - { - if (handler.IsInterchangeable(otherBlock.Handler)) - { - //no need to look any further, we already found the perfect match and can return without conversion - return otherBlock.NamedBlockSections; - } - - bestMatchedBlock = otherBlock; - } - } - - if (bestMatchedBlock == null) - { - return null; - } - - //_availableBlocksSemaphore.WaitOne(); - //_availableBlocksSemaphoreState--; - - return ConvertNamedBlocks(bestMatchedBlock.NamedBlockSections, handler); - } - - private static Dictionary ConvertNamedBlocks(Dictionary namedBlockSections, IComputationHandler handler) - { - Dictionary convertedNamedBlocks = new Dictionary(); - - foreach (string name in namedBlockSections.Keys) - { - convertedNamedBlocks.Add(name, handler.Convert(namedBlockSections[name], handler)); - } - - return convertedNamedBlocks; - } - - private Dictionary LoadAndExtractRaw(int blockIndex, IComputationHandler handler) - { - // this cannot run concurrently as cache entries can only be read and written once without wasting resources and / or corrupting cache state - lock (this) - { - if (blockIndex >= _lastReadRawDataBlockIndex) - { - object[] lastRawData = null; - - for (int tempBlockIndex = _lastReadRawDataBlockIndex + 1; tempBlockIndex <= blockIndex; tempBlockIndex++) - { - lastRawData = LoadDirect(tempBlockIndex, handler); - - //looks like we couldn't read any more blocks, maybe reached the end of the underlying source streams - if (lastRawData == null) - { - return null; - } - - if (AllowRawReadDataCaching) - { - _cacheProvider.Store($"raw.{tempBlockIndex}", lastRawData); - } - } - - return ExtractDirectFrom(lastRawData, blockIndex, handler); - } - else - { - if (AllowRawReadDataCaching) - { - string cacheIdentifier = $"raw.{blockIndex}"; - - if 
(!_cacheProvider.IsCached(cacheIdentifier)) - { - throw new InvalidOperationException($"Unable to load cached entry for block {blockIndex} for handler {handler}, cache entry does not exist in provider {_cacheProvider}."); - } - - return ExtractDirectFrom(_cacheProvider.Load(cacheIdentifier), blockIndex, handler); - } - else - { - throw new InvalidOperationException($"Cannot load and extract raw block with index {blockIndex} because AllowRawReadDataCaching is set to false and last read position is at {_lastReadRawDataBlockIndex}."); - } - } - } - } - - private object[] LoadDirect(int blockIndex, IComputationHandler handler) - { - IList rawDataPerExtractor = new List(); - - PrepareExtractors(); - - foreach (IRecordExtractor extractor in _recordExtractors) - { - object data; - - lock (extractor.Reader) - { - data = extractor.Reader.Read(TargetBlockSizeRecords); - } - - //check if block reader could read anything, if not, return null - if (data == null) - { - _lastAvailableBlockIndex = blockIndex - 1; - - _logger.Debug($"Cannot load block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. 
End of stream most likely reached."); - - return null; - } - - rawDataPerExtractor.Add(data); - } - - if (blockIndex > _lastReadRawDataBlockIndex) - { - _lastReadRawDataBlockIndex = blockIndex; - } - - return rawDataPerExtractor.ToArray(); - } - - private Dictionary ExtractDirectFrom(object[] data, int blockIndex, IComputationHandler handler) - { - Dictionary namedBlocks = new Dictionary(); - - ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare, "preparing extractors for dataset \"" + Name + "\"", indeterminate: true); - - PrepareExtractors(); - - SigmaEnvironment.TaskManager.EndTask(prepareTask); - - ITaskObserver extractTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Extract, $"extracting block {blockIndex} for dataset \"{Name}\"", indeterminate: true); - - int extractorIndex = 0; - foreach (IRecordExtractor extractor in _recordExtractors) - { - _logger.Debug($"Extracting hierarchically from extractor {extractor} at index {extractorIndex}..."); - - Dictionary subNamedBlock = extractor.ExtractHierarchicalFrom(data[extractorIndex++], TargetBlockSizeRecords, handler); - - //check if block size is 0, indicating we reached the end of the stream - if (subNamedBlock == null) - { - _lastAvailableBlockIndex = blockIndex - 1; - - _logger.Debug($"Cannot extract block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. 
End of stream most likely reached."); - - SigmaEnvironment.TaskManager.CancelTask(extractTask); - - return null; - } - - foreach (string name in subNamedBlock.Keys) - { - if (namedBlocks.ContainsKey(name)) - { - SigmaEnvironment.TaskManager.CancelTask(extractTask); - - throw new ArgumentException($"Section name collision: {name} is already used by another extractor, current extractor {extractor} cannot use it again."); - } - else - { - namedBlocks.Add(name, subNamedBlock[name]); - } - } - } - - SigmaEnvironment.TaskManager.EndTask(extractTask); - - return namedBlocks; - } - - public void FreeBlock(int blockIndex, IComputationHandler handler) - { - if (!_activeBlocks.ContainsKey(blockIndex)) - { - _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active."); - - return; - } - - foreach (RecordBlock block in _activeBlocks[blockIndex]) - { - if (ReferenceEquals(block.Handler, handler)) - { - _logger.Debug($"Freeing block with index {blockIndex} for handler {handler}..."); - - CacheBlockConstrained(block.NamedBlockSections, blockIndex, handler); - - DeregisterActiveBlock(block); - - //_availableBlocksSemaphore.Release(); - //_availableBlocksSemaphoreState++; - - _logger.Debug($"Done freeing block with index {blockIndex} for handler {handler}."); - - return; - } - } - - _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active."); - } - - private void CacheBlockConstrained(Dictionary block, int blockIndex, IComputationHandler handler) - { - if (_cachedBlocks.ContainsKey(blockIndex)) - { - foreach (WeakRecordBlock cachedBlock in _cachedBlocks[blockIndex]) - { - //check if block of the same type and size is already cached, if so, return, because there is no need to cache again - if (cachedBlock.BlockIndex == blockIndex && cachedBlock.Handler.IsInterchangeable(handler) && block.First().Value.Shape[0] == 
cachedBlock.NumberRecords) - { - _logger.Debug($"Skipping cache request of block {blockIndex} for handler {handler} because interchangeable block of same index, format and size is already cached."); - - return; - } - } - } - - long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray()); - - if (_cachedBlocks.Count >= MaxBlocksInCache) - { - _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBlocksInCache constraint of {MaxBlocksInCache}."); - - return; - } - - if (blockSizeBytes + _totalCachedBlockSizeBytes >= MaxBytesInCache) - { - _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBytesInCache constraint of {MaxBytesInCache} bytes (block of size {blockSizeBytes} would exceed constraint by {_totalCachedBlockSizeBytes + blockSizeBytes - MaxBytesInCache} bytes)."); - - return; - } - - string cacheIdentifier = $"extracted.{blockIndex}.{handler.DataType.Identifier}"; - - _cacheProvider.Store(cacheIdentifier, block); - - bool keepReference = TotalActiveBlockSizeBytes + blockSizeBytes < MaxTotalActiveBlockSizeBytes; - - RegisterCachedBlock(block, blockIndex, handler, keepReference); - - _totalCachedBlockSizeBytes += blockSizeBytes; - } - - private void PrepareExtractors() - { - foreach (IRecordExtractor extractor in _recordExtractors) - { - lock (extractor) - { - extractor.Prepare(); - } - } - } - - public long GetBlockSizeBytes(int blockIndex, IComputationHandler handler) - { - if (!_activeBlocks.ContainsKey(blockIndex)) - { - return -1L; - } - - foreach (RecordBlock block in _activeBlocks[blockIndex]) - { - if (ReferenceEquals(block.Handler, handler)) - { - return block.EstimatedSizeBytes; - } - } - - return -1L; - } - - public bool IsBlockActive(int blockIndex) - { - return _activeBlocks.ContainsKey(blockIndex); - } - - public bool IsBlockActive(int blockIndex, IComputationHandler handler) - { - if (!_activeBlocks.ContainsKey(blockIndex)) - { - return false; - } - - foreach (RecordBlock block in 
_activeBlocks[blockIndex]) - { - if (ReferenceEquals(block.Handler, handler)) - { - return true; - } - } - - return false; - } - - private bool IsBlockCached(int blockIndex, IComputationHandler handler) - { - if (!_cachedBlocks.ContainsKey(blockIndex)) - { - return false; - } - - foreach (WeakRecordBlock block in _cachedBlocks[blockIndex]) - { - if (ReferenceEquals(block.Handler, handler)) - { - return true; - } - } - - return false; - } - - public void Dispose() - { - foreach (IRecordExtractor extractor in _recordExtractors) - { - extractor.Dispose(); - extractor.Reader?.Dispose(); - } - - _cacheProvider.Dispose(); - } - - public static IDataset[] SplitBlockwise(IDataset dataset, params int[] parts) - { - if (parts.Length == 0) - { - throw new ArgumentException("Parts cannot be an empty collection."); - } - - int splitInterval = parts.Sum(); - int lastEnd = 0; - IDataset[] slices = new IDataset[parts.Length]; - - for (int i = 0; i < parts.Length; i++) - { - slices[i] = new DatasetBlockwiseSlice(dataset, lastEnd, lastEnd + parts[i] - 1, splitInterval); - lastEnd += parts[i]; - } - - return slices; - } - - public static IDataset[] SplitRecordwise(IDataset dataset, params double[] parts) - { - if (parts.Length == 0) - { - throw new ArgumentException("Percentages cannot be an empty collection."); - } - - if (parts.Sum() > 1.0) - { - throw new ArgumentException($"Percentages sum cannot be > 1.0, but parts sum was {parts.Sum()}."); - } - - IDataset[] slices = new IDataset[parts.Length]; - - double lastOffset = 0.0; - - for (int i = 0; i < slices.Length; i++) - { - slices[i] = new DatasetRecordwiseSlice(dataset, lastOffset, parts[i]); - - lastOffset += parts[i]; - } - - return slices; - } - - internal abstract class RecordBlockBase - { - internal abstract Dictionary NamedBlockSections { get; set; } - internal abstract INDArray FirstNamedBlock { get; set; } - internal abstract bool Loaded { get; set; } - - internal IComputationHandler Handler; - internal bool Active; - 
internal int BlockIndex; - internal long NumberRecords; - internal long EstimatedSizeBytes; - } - - internal class RecordBlock : RecordBlockBase - { - internal sealed override Dictionary NamedBlockSections { get; set; } - internal sealed override INDArray FirstNamedBlock { get; set; } - internal override bool Loaded { get; set; } - - public RecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler) - { - NamedBlockSections = namedBlockSections; - BlockIndex = blockIndex; - NumberRecords = numberRecords; - EstimatedSizeBytes = estimatedSizeBytes; - Handler = handler; - - //record blocks internal block can be null - if (namedBlockSections != null) - { - FirstNamedBlock = namedBlockSections[namedBlockSections.First().Key]; - } - } - } - - internal class WeakRecordBlock : RecordBlockBase - { - internal override Dictionary NamedBlockSections - { - get - { - Dictionary target; - - return _namedBlockSections.TryGetTarget(out target) ? target : null; - } - set - { - _namedBlockSections.SetTarget(value); - } - } - - internal override INDArray FirstNamedBlock - { - get - { - INDArray target; - - return _firstNamedBlock.TryGetTarget(out target) ? 
target : null; - } - set - { - _firstNamedBlock.SetTarget(value); - } - } - - internal override bool Loaded - { - get - { - Dictionary target; - - return _namedBlockSections.TryGetTarget(out target); - } - set - { - } - } - - private readonly WeakReference> _namedBlockSections; - private readonly WeakReference _firstNamedBlock; - - public WeakRecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler) - { - _namedBlockSections = new WeakReference>(namedBlockSections); - BlockIndex = blockIndex; - NumberRecords = numberRecords; - EstimatedSizeBytes = estimatedSizeBytes; - Handler = handler; - - //record blocks internal block can be null - if (namedBlockSections != null) - { - _firstNamedBlock = new WeakReference(namedBlockSections[namedBlockSections.First().Key]); - } - } - } - - public override string ToString() - { - return $"dataset \"{Name}\""; - } - } -} diff --git a/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs b/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs index 22797aa7..31509525 100644 --- a/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs +++ b/Sigma.Core/Data/Datasets/DatasetBlockwiseSlice.cs @@ -50,7 +50,7 @@ public long MaxBytesInCache } public string[] SectionNames => UnderlyingDataset.SectionNames; public IReadOnlyCollection ActiveBlockIndices => UnderlyingDataset.ActiveBlockIndices; - public int ActiveIndividualBlockCount => UnderlyingDataset.ActiveIndividualBlockCount; + public int ActiveIndividualBlockRegionCount => UnderlyingDataset.ActiveIndividualBlockRegionCount; public int ActiveBlockRegionCount => UnderlyingDataset.ActiveBlockRegionCount; /// @@ -112,12 +112,12 @@ protected int MapToUnderlyingIndex(int blockIndex) public IDataset[] SplitBlockwise(params int[] parts) { - return Dataset.SplitBlockwise(this, parts); + return ExtractedDataset.SplitBlockwise(this, parts); } public IDataset[] SplitRecordwise(params double[] parts) { - return 
Dataset.SplitRecordwise(this, parts); + return ExtractedDataset.SplitRecordwise(this, parts); } public bool TrySetBlockSize(int blockSizeRecords) diff --git a/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs b/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs index baf2fe84..f8714d6b 100644 --- a/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs +++ b/Sigma.Core/Data/Datasets/DatasetRecordwiseSlice.cs @@ -48,7 +48,7 @@ public long MaxBytesInCache } public string[] SectionNames => UnderlyingDataset.SectionNames; public IReadOnlyCollection ActiveBlockIndices => UnderlyingDataset.ActiveBlockIndices; - public int ActiveIndividualBlockCount => UnderlyingDataset.ActiveIndividualBlockCount; + public int ActiveIndividualBlockRegionCount => UnderlyingDataset.ActiveIndividualBlockRegionCount; public int ActiveBlockRegionCount => UnderlyingDataset.ActiveBlockRegionCount; /// @@ -89,12 +89,12 @@ public DatasetRecordwiseSlice(IDataset underlyingDataset, double shareOffset, do public IDataset[] SplitBlockwise(params int[] parts) { - return Dataset.SplitBlockwise(this, parts); + return ExtractedDataset.SplitBlockwise(this, parts); } public IDataset[] SplitRecordwise(params double[] parts) { - return Dataset.SplitRecordwise(this, parts); + return ExtractedDataset.SplitRecordwise(this, parts); } public bool TrySetBlockSize(int blockSizeRecords) diff --git a/Sigma.Core/Data/Datasets/ExtractedDataset.cs b/Sigma.Core/Data/Datasets/ExtractedDataset.cs new file mode 100644 index 00000000..9ca598a9 --- /dev/null +++ b/Sigma.Core/Data/Datasets/ExtractedDataset.cs @@ -0,0 +1,1069 @@ +/* +MIT License + +Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer + +For full license see LICENSE in the root directory of this project. 
+*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using log4net; +using Sigma.Core.Data.Extractors; +using Sigma.Core.Handlers; +using Sigma.Core.MathAbstract; +using Sigma.Core.Persistence; +using Sigma.Core.Utils; + +namespace Sigma.Core.Data.Datasets +{ + /// + /// A default implementation of the IDataset interface. + /// Provides caching of entire blocks and reader data, partial extraction, unordered extraction, automatic block sizing, smart block loading. + /// + [Serializable] + public class ExtractedDataset : IDataset, ISerialisationNotifier + { + [NonSerialized] + private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); + + /// + /// Automatically size blocks according to estimated data metrics (e.g. physical memory available, record size). + /// + public const int BlockSizeAuto = -1; + + /// + /// Assign all available data to the first block (one block fits it all - literally). + /// + public const int BlockSizeAll = -2; + + /// + public string Name { get; } + + /// + /// Indicate if this dataset is an online dataset (meaning new data might be added during runtime). + /// By default, this is assumed to be false, indicating a static dataset. + /// Note: Data iterators and may perform certain optimisations for static datasets, so set this to false if possible. + /// + public bool Online { get; set; } = false; + + /// + public int MaxConcurrentActiveBlocks { get; } = 24; //24 seems like a good number, right? 
+ + /// + public long MaxTotalActiveBlockSizeBytes { get; } = SystemInformationUtils.GetAvailablePhysicalMemoryBytes() / 2; //default to half the available physical memory + + /// + public IReadOnlyCollection ActiveBlockIndices => _activeBlocks.Keys.ToList(); + + /// + public int ActiveBlockRegionCount => _activeBlocks.Count; + + /// + public int ActiveIndividualBlockRegionCount { get { return _activeBlocks.Values.Sum(set => set.Count); } } + + /// + public int TargetBlockSizeRecords { get; private set; } + + /// + public string[] SectionNames { get; private set; } + + /// + public long TotalActiveBlockSizeBytes { get; private set; } + + public long TotalActiveRecords { get; private set; } + + /// + public int MaxBlocksInCache { get; set; } = int.MaxValue; + + /// + public long MaxBytesInCache { get; set; } = long.MaxValue; + + /// + /// Indicate whether this dataset should cache the raw reader data. + /// If disabled, only extracted data will be cached and once processed, it might be impossible to retrieve preceding record blocks (reader streams are assumed to be non-seekable). + /// + public bool AllowRawReadDataCaching { get; set; } = true; + + private readonly Dictionary> _activeBlocks; + private readonly Dictionary> _cachedBlocks; + private readonly ICacheProvider _cacheProvider; + + private int _lastReadRawDataBlockIndex = -1; + private long _totalCachedBlockSizeBytes; + private int _lastAvailableBlockIndex = int.MaxValue; + private readonly ISet _recordExtractors; + + private readonly bool _autoSetBlockSize; + private bool _autoSetExternalChangeBlockSize; + + // TODO fix available blocks semaphore logic + // the waitones/releases are inconsistent, because blocks aren't always actually allocated, such as null returns are not considered + [NonSerialized] + private Semaphore _availableBlocksSemaphore; + private int _availableBlocksSemaphoreState; + + /// + /// Create a dataset with a certain unique name and the record extractors to use. 
+ /// + /// The unique dataset name. + /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. + public ExtractedDataset(string name, params IRecordExtractor[] recordExtractors) : this(name, BlockSizeAuto, recordExtractors) + { + } + + /// + /// Create a dataset with a certain unique name, target block size in records and the record extractors to use. + /// + /// The unique dataset name. + /// The target block size for records. May also be or . + /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. + public ExtractedDataset(string name, int blockSizeRecords, params IRecordExtractor[] recordExtractors) + : this(name, blockSizeRecords, true, recordExtractors) + { + } + + /// + /// Create a dataset with a certain unique name, target block size in records and the record extractors to use. + /// + /// The unique dataset name. + /// The target block size for records. May also be or . + /// Indicate whether the cache provider should be flushed (cleared) before use. Only disable if block size and extractors used do not change (otherwise undefined behaviour). + /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. + public ExtractedDataset(string name, int blockSizeRecords, bool flushCache, params IRecordExtractor[] recordExtractors) + : this(name, blockSizeRecords, new DiskCacheProvider(SigmaEnvironment.Globals.Get("cache_path") + name), true, recordExtractors) + { + } + + /// + /// Create a dataset with a certain unique name, target block size in records, specific cache provider and the record extractors to use. + /// + /// The unique dataset name. + /// The target block size for records. May also be or . + /// The cache provider to use for caching record blocks and raw reader data. + /// Indicate whether the cache provider should be flushed (cleared) before use. 
Only disable if block size and extractors used do not change (otherwise undefined behaviour). + /// The record extractors to fetch the data from, which provide the dataset with ready to use record blocks. + public ExtractedDataset(string name, int blockSizeRecords, ICacheProvider cacheProvider, bool flushCache = true, params IRecordExtractor[] recordExtractors) + { + if (name == null) + { + throw new ArgumentNullException(nameof(name)); + } + + if (recordExtractors == null) + { + throw new ArgumentNullException(nameof(recordExtractors)); + } + + if (recordExtractors.Length == 0) + { + throw new ArgumentException("Datasets require at least one record extractor, but none were given."); + } + + if (cacheProvider == null) + { + throw new ArgumentNullException(nameof(cacheProvider)); + } + + switch (blockSizeRecords) + { + case BlockSizeAll: + //just set to maximum amount of records, extracting returns the maximum available anyway and we can't know the actual availability yet + TargetBlockSizeRecords = int.MaxValue; + break; + case BlockSizeAuto: + //somewhat temporary guesstimate, should probably expose the individual parameters + const long estimatedRecordSizeBytes = 1024; + const double memoryToConsume = 0.2f; + const long optimalNumberBlocks = 8; + const int maxBlockSizeRecords = 4096; + long availableSystemMemory = SystemInformationUtils.GetAvailablePhysicalMemoryBytes(); + + TargetBlockSizeRecords = Math.Min(maxBlockSizeRecords, (int)(availableSystemMemory * memoryToConsume / estimatedRecordSizeBytes / optimalNumberBlocks)); + + _autoSetBlockSize = true; + break; + default: + if (blockSizeRecords == 0 || blockSizeRecords < -2) + { + throw new ArgumentException($"Block size in records must be either BLOCK_SIZE_ALL, BLOCK_SIZE_AUTO or > 0, but given block size was {blockSizeRecords}."); + } + else + { + TargetBlockSizeRecords = blockSizeRecords; + } + break; + } + + Name = name; + AnalyseExtractors(recordExtractors); + + _cacheProvider = cacheProvider; + 
_recordExtractors = new HashSet(recordExtractors); + + _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks, MaxConcurrentActiveBlocks); + _availableBlocksSemaphoreState = MaxConcurrentActiveBlocks; + + _activeBlocks = new Dictionary>(); + _cachedBlocks = new Dictionary>(); + + if (flushCache) + { + _logger.Debug($"Flushing all caches for dataset \"{Name}\" as flushCache flag was set..."); + + InvalidateAndClearCaches(); + + _logger.Debug($"Done flushing all caches for dataset \"{Name}.\""); + } + } + + /// + /// Called before this object is serialised. + /// + public void OnSerialising() + { + } + + /// + /// Called after this object was serialised. + /// + public void OnSerialised() + { + } + + /// + /// Called after this object was de-serialised. + /// + public void OnDeserialised() + { + InvalidateAndClearCaches(); + _availableBlocksSemaphore = new Semaphore(MaxConcurrentActiveBlocks - ActiveIndividualBlockRegionCount, MaxConcurrentActiveBlocks); + } + + public IDataset[] SplitBlockwise(params int[] parts) + { + return SplitBlockwise(this, parts); + } + + public IDataset[] SplitRecordwise(params double[] parts) + { + return SplitRecordwise(this, parts); + } + + public bool TrySetBlockSize(int blockSizeRecords) + { + if (blockSizeRecords == TargetBlockSizeRecords) + { + //nothing to do here + return true; + } + + if (!_autoSetBlockSize) + { + _logger.Debug($"Cannot change block size as block size was not set automatically (attempted to change block size to {blockSizeRecords}."); + + return false; + } + + if (_activeBlocks.Count > 0 || _cachedBlocks.Count > 0) + { + _logger.Debug($"Cannot change block size as {_activeBlocks.Count + _cachedBlocks.Count} blocks were already fetched and are active or cached."); + + return false; + } + + if (_autoSetExternalChangeBlockSize && blockSizeRecords != TargetBlockSizeRecords) + { + _logger.Debug($"Cannot change block size to {blockSizeRecords}, block size is incompatible with another external block size 
change request (other request: {TargetBlockSizeRecords})"); + + return false; + } + + _autoSetExternalChangeBlockSize = true; + TargetBlockSizeRecords = blockSizeRecords; + + return true; + } + + private void AnalyseExtractors(IEnumerable extractors) + { + ISet sectionNames = new HashSet(); + + int index = 0; + foreach (IRecordExtractor extractor in extractors) + { + if (extractor == null) + { + throw new ArgumentNullException($"Extractor at index {index} was null."); + } + + if (extractor.SectionNames == null) + { + throw new ArgumentNullException($"Section names field in extractor {extractor} was null (field has to be set by extractor)."); + } + + string[] extractorSectionNames = extractor.SectionNames; + + foreach (string sectionName in extractorSectionNames) + { + if (sectionNames.Contains(sectionName)) + { + throw new ArgumentException($"Section name collision: duplicate section name {sectionName} detected for extractor {extractor}."); + } + else + { + sectionNames.Add(sectionName); + } + } + + index++; + } + + SectionNames = sectionNames.ToArray(); + } + + public int GetNumberOfLoadedInactiveCachedBlocks() + { + return _cachedBlocks.Values.SelectMany(blockSet => blockSet).Count(block => block.Loaded); + } + + public bool CanFetchBlocksAfter(int blockIndex) + { + return blockIndex <= _lastAvailableBlockIndex; + } + + public async Task> FetchBlockAsync(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true) + { + //TODO check if block even could be fetched to not waste thread resources if shouldWaitUntilAvailable is false anyway + + return await Task.Run(() => FetchBlock(blockIndex, handler, shouldWaitUntilAvailable)); + } + + public IDictionary FetchBlock(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true) + { + Dictionary block = FetchBlockConstrained(blockIndex, handler); + + //block could be fetched directly without violating any constraints, return successfully + if (block != null) + { + if 
(block.Count == 0) + { + throw new InvalidOperationException("Fetched block did not contain any named elements (was empty; is the extractor output correct?)."); + } + + RegisterActiveBlock(block, blockIndex, handler); + + return block; + } + else + { + if (blockIndex >= _lastAvailableBlockIndex) + { + return null; + } + + if (shouldWaitUntilAvailable) + { + _logger.Debug($"Could not directly load block with index {blockIndex} for handler {handler} and shouldWaitUntilAvailable flag is set to true, waiting for available space..."); + + return FetchBlockWhenAvailable(blockIndex, handler); + } + else + { + return null; + } + } + } + + private void RegisterActiveBlock(Dictionary block, int blockIndex, IComputationHandler handler) + { + INDArray firstNamedBlock = block[block.First().Key]; + + if (IsBlockActive(blockIndex, handler)) + { + //block already registered as active, nothing to do here + return; + } + + RecordBlock recordBlock = new RecordBlock(block, blockIndex, firstNamedBlock.Shape[0], handler.GetSizeBytes(block.Values.ToArray()), handler) + { Loaded = true, Active = true }; + + lock (_activeBlocks) + { + TotalActiveBlockSizeBytes += recordBlock.EstimatedSizeBytes; + TotalActiveRecords += recordBlock.NumberRecords; + + if (!_activeBlocks.ContainsKey(blockIndex)) + { + _activeBlocks.Add(blockIndex, new HashSet()); + } + + _activeBlocks[blockIndex].Add(recordBlock); + } + } + + private void DeregisterActiveBlock(RecordBlock recordBlock) + { + if (!IsBlockActive(recordBlock.BlockIndex, recordBlock.Handler)) + { + //block that should be de-registered is not even registered + return; + } + + lock (this) + { + TotalActiveBlockSizeBytes -= recordBlock.EstimatedSizeBytes; + TotalActiveRecords -= recordBlock.NumberRecords; + + _activeBlocks[recordBlock.BlockIndex].Remove(recordBlock); + + if (_activeBlocks[recordBlock.BlockIndex].Count == 0) + { + _activeBlocks.Remove(recordBlock.BlockIndex); + } + } + } + + private void RegisterCachedBlock(Dictionary block, int 
blockIndex, IComputationHandler handler, bool keepReference) + { + if (IsBlockCached(blockIndex, handler)) + { + //block's already cached, nothing to do here + return; + } + + if (!_cachedBlocks.ContainsKey(blockIndex)) + { + _cachedBlocks.Add(blockIndex, new HashSet()); + } + + WeakRecordBlock recordBlock = new WeakRecordBlock(keepReference ? block : null, blockIndex, block.First().Value.Shape[0], handler.GetSizeBytes(block.Values.ToArray()), handler); + + recordBlock.Loaded = false; + + _cachedBlocks[blockIndex].Add(recordBlock); + } + + /// + /// Invalidate and clear all caches associated with this dataset. + /// WARNING: Removing cache entries may cause certain datasets to load much more slowly or even incorrectly. + /// Legitimate use cases include removing cache entries for old datasets or changing extractors. + /// + public void InvalidateAndClearCaches() + { + _logger.Debug("Invalidating and clearing all caches..."); + + _cacheProvider.RemoveAll(); + _cachedBlocks.Clear(); + _totalCachedBlockSizeBytes = 0L; + + _logger.Debug("Done invalidating and clearing all caches."); + } + + private Dictionary FetchBlockWhenAvailable(int blockIndex, IComputationHandler handler) + { + while (true) + { + _logger.Debug($"Attempting to extract block region for request for block index {blockIndex} for handler {handler}, checking if it fits all constraints..."); + + Dictionary block = FetchBlockConstrained(blockIndex, handler); + + //if block != null we could fetch the block successfully without violating any constraints + if (block != null) + { + RegisterActiveBlock(block, blockIndex, handler); + + return block; + } + else + { + //we cannot retrieve any more blocks and shouldn't keep trying + if (blockIndex >= _lastAvailableBlockIndex) + { + return null; + } + + _logger.Debug($"Request for block with index {blockIndex} for handler {handler} was returned to the queue, seems to be violating constraints..."); + } + } + } + + private Dictionary FetchBlockConstrained(int 
blockIndex, IComputationHandler handler) + { + if (ActiveIndividualBlockRegionCount >= MaxConcurrentActiveBlocks) + { + _logger.Debug($"Unable to fetch block due to MaxConcurrentActiveBlocks constraint of {MaxConcurrentActiveBlocks}."); + + return null; + } + + Dictionary block = LoadAndExtractBlockWhenAvailable(blockIndex, handler); + + //there was nothing to load and extract, most likely end of stream + if (block == null) + { + return null; + } + + long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray()); + + if (TotalActiveBlockSizeBytes + blockSizeBytes > MaxTotalActiveBlockSizeBytes) + { + _logger.Debug($"Unable to keep requested block {blockIndex} for handler {handler} in memory due to MaxTotalActiveBlockSizeBytes constraint of {MaxTotalActiveBlockSizeBytes} bytes (block of size {blockSizeBytes} would exceed constraint by {TotalActiveBlockSizeBytes + blockSizeBytes - MaxTotalActiveBlockSizeBytes} bytes.)."); + + CacheBlockConstrained(block, blockIndex, handler); + + return null; + } + + return block; + } + + private Dictionary LoadAndExtractBlockWhenAvailable(int blockIndex, IComputationHandler handler) + { + //this method takes care of + // - checking whether the index is already loaded and active and then converts it + // - or checking whether the index is already cached in the right format and loads + // - or if none of that, loads and extracts from the original extractors + + //check whether a block with the same index and format is already active + if (_activeBlocks.ContainsKey(blockIndex)) + { + Dictionary block = GetBestMatchedBlockWhenAvailable(_activeBlocks[blockIndex], handler); + + if (block != null) + { + return block; + } + } + + //check whether a block with the same index and format is already loaded and cached but not active + if (_cachedBlocks.ContainsKey(blockIndex)) + { + Dictionary block = GetBestMatchedBlockWhenAvailable(_cachedBlocks[blockIndex], handler); + + if (block != null) + { + return block; + } + } + + lock 
(_cacheProvider) + { + string blockIdentifierInCache = $"extracted.{blockIndex}.{handler.DataType.Identifier}"; + + //check whether a block of the same index and format is cached in the cache provider + if (_cacheProvider.IsCached(blockIdentifierInCache)) + { + Dictionary block = _cacheProvider.Load>(blockIdentifierInCache); + + //if its != null we could read it correctly in the right format + if (block != null) + { + //register this cache entry as a properly loaded block in case the cache wasn't flushed and the cache map is outdated + RegisterCachedBlock(block, blockIndex, handler, keepReference: false); + + return block; + } + } + } + + //_availableBlocksSemaphore.WaitOne(); + //_availableBlocksSemaphoreState--; + + return LoadAndExtractRaw(blockIndex, handler); + } + + private Dictionary GetBestMatchedBlockWhenAvailable(IEnumerable blocks, IComputationHandler handler) + { + RecordBlockBase bestMatchedBlock = null; + + foreach (RecordBlockBase otherBlock in blocks) + { + if (otherBlock.Loaded && handler.CanConvert(otherBlock.FirstNamedBlock, otherBlock.Handler)) + { + if (handler.IsInterchangeable(otherBlock.Handler)) + { + //no need to look any further, we already found the perfect match and can return without conversion + return otherBlock.NamedBlockSections; + } + + bestMatchedBlock = otherBlock; + } + } + + if (bestMatchedBlock == null) + { + return null; + } + + //_availableBlocksSemaphore.WaitOne(); + //_availableBlocksSemaphoreState--; + + return ConvertNamedBlocks(bestMatchedBlock.NamedBlockSections, handler); + } + + private static Dictionary ConvertNamedBlocks(Dictionary namedBlockSections, IComputationHandler handler) + { + Dictionary convertedNamedBlocks = new Dictionary(); + + foreach (string name in namedBlockSections.Keys) + { + convertedNamedBlocks.Add(name, handler.Convert(namedBlockSections[name], handler)); + } + + return convertedNamedBlocks; + } + + private Dictionary LoadAndExtractRaw(int blockIndex, IComputationHandler handler) + { + // 
this cannot run concurrently as cache entries can only be read and written once without wasting resources and / or corrupting cache state + lock (this) + { + if (blockIndex > _lastReadRawDataBlockIndex) + { + object[] lastRawData = null; + + for (int tempBlockIndex = _lastReadRawDataBlockIndex + 1; tempBlockIndex <= blockIndex; tempBlockIndex++) + { + lastRawData = LoadDirect(tempBlockIndex, handler); + + //looks like we couldn't read any more blocks, maybe reached the end of the underlying source streams + if (lastRawData == null) + { + return null; + } + + if (AllowRawReadDataCaching) + { + _cacheProvider.Store($"raw.{tempBlockIndex}", lastRawData); + } + } + + return ExtractDirectFrom(lastRawData, blockIndex, handler); + } + else + { + if (AllowRawReadDataCaching) + { + string cacheIdentifier = $"raw.{blockIndex}"; + + if (!_cacheProvider.IsCached(cacheIdentifier)) + { + throw new InvalidOperationException($"Unable to load cached entry for block {blockIndex} for handler {handler}, cache entry does not exist in provider {_cacheProvider}."); + } + + return ExtractDirectFrom(_cacheProvider.Load(cacheIdentifier), blockIndex, handler); + } + else + { + throw new InvalidOperationException($"Cannot load and extract raw block with index {blockIndex} because AllowRawReadDataCaching is set to false and last read position is at {_lastReadRawDataBlockIndex}."); + } + } + } + } + + private object[] LoadDirect(int blockIndex, IComputationHandler handler) + { + IList rawDataPerExtractor = new List(); + + PrepareExtractors(); + + foreach (IRecordExtractor extractor in _recordExtractors) + { + object data; + + lock (extractor.Reader) + { + data = extractor.Reader.Read(TargetBlockSizeRecords); + } + + //check if block reader could read anything, if not, return null + if (data == null) + { + _lastAvailableBlockIndex = blockIndex - 1; + + _logger.Debug($"Cannot load block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any 
more records. End of stream most likely reached."); + + return null; + } + + rawDataPerExtractor.Add(data); + } + + if (blockIndex > _lastReadRawDataBlockIndex) + { + _lastReadRawDataBlockIndex = blockIndex; + } + + return rawDataPerExtractor.ToArray(); + } + + private Dictionary ExtractDirectFrom(object[] data, int blockIndex, IComputationHandler handler) + { + Dictionary namedBlocks = new Dictionary(); + + ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare, "preparing extractors for dataset \"" + Name + "\"", indeterminate: true); + + PrepareExtractors(); + + SigmaEnvironment.TaskManager.EndTask(prepareTask); + + ITaskObserver extractTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Extract, $"extracting block {blockIndex} for dataset \"{Name}\"", indeterminate: true); + + int extractorIndex = 0; + foreach (IRecordExtractor extractor in _recordExtractors) + { + _logger.Debug($"Extracting hierarchically from extractor {extractor} at index {extractorIndex}..."); + + Dictionary subNamedBlock = extractor.ExtractHierarchicalFrom(data[extractorIndex++], TargetBlockSizeRecords, handler); + + //check if block size is 0, indicating we reached the end of the stream + if (subNamedBlock == null) + { + _lastAvailableBlockIndex = blockIndex - 1; + + _logger.Debug($"Cannot extract block {blockIndex} for handler {handler}, the underlying stream for extractor {extractor} is unable to retrieve any more records. 
End of stream most likely reached."); + + SigmaEnvironment.TaskManager.CancelTask(extractTask); + + return null; + } + + foreach (string name in subNamedBlock.Keys) + { + if (namedBlocks.ContainsKey(name)) + { + SigmaEnvironment.TaskManager.CancelTask(extractTask); + + throw new ArgumentException($"Section name collision: {name} is already used by another extractor, current extractor {extractor} cannot use it again."); + } + else + { + namedBlocks.Add(name, subNamedBlock[name]); + } + } + } + + SigmaEnvironment.TaskManager.EndTask(extractTask); + + return namedBlocks; + } + + public void FreeBlock(int blockIndex, IComputationHandler handler) + { + lock (_activeBlocks) + { + if (!_activeBlocks.ContainsKey(blockIndex)) + { + _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active."); + + return; + } + + RecordBlock toRemove = null; + + foreach (RecordBlock block in _activeBlocks[blockIndex]) + { + if (ReferenceEquals(block.Handler, handler)) + { + _logger.Debug($"Freeing block with index {blockIndex} for handler {handler}..."); + + CacheBlockConstrained(block.NamedBlockSections, blockIndex, handler); + + //_availableBlocksSemaphore.Release(); + //_availableBlocksSemaphoreState++; + + toRemove = block; + + goto FoundBlock; + } + } + + _logger.Debug($"Unable to free block with index {blockIndex} for handler {handler} because no block with that information is currently active."); + + FoundBlock: + + DeregisterActiveBlock(toRemove); + _logger.Debug($"Done freeing block with index {blockIndex} for handler {handler}."); + + } + } + + private void CacheBlockConstrained(Dictionary block, int blockIndex, IComputationHandler handler) + { + if (_cachedBlocks.ContainsKey(blockIndex)) + { + foreach (WeakRecordBlock cachedBlock in _cachedBlocks[blockIndex]) + { + //check if block of the same type and size is already cached, if so, return, because there is no need to cache again + if 
(cachedBlock.BlockIndex == blockIndex && cachedBlock.Handler.IsInterchangeable(handler) && block.First().Value.Shape[0] == cachedBlock.NumberRecords) + { + _logger.Debug($"Skipping cache request of block {blockIndex} for handler {handler} because interchangeable block of same index, format and size is already cached."); + + return; + } + } + } + + long blockSizeBytes = handler.GetSizeBytes(block.Values.ToArray()); + + if (_cachedBlocks.Count >= MaxBlocksInCache) + { + _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBlocksInCache constraint of {MaxBlocksInCache}."); + + return; + } + + if (blockSizeBytes + _totalCachedBlockSizeBytes >= MaxBytesInCache) + { + _logger.Debug($"Unable to cache block {blockIndex} for handler {handler} due to MaxBytesInCache constraint of {MaxBytesInCache} bytes (block of size {blockSizeBytes} would exceed constraint by {_totalCachedBlockSizeBytes + blockSizeBytes - MaxBytesInCache} bytes)."); + + return; + } + + string cacheIdentifier = $"extracted.{blockIndex}.{handler.DataType.Identifier}"; + + _cacheProvider.Store(cacheIdentifier, block); + + bool keepReference = TotalActiveBlockSizeBytes + blockSizeBytes < MaxTotalActiveBlockSizeBytes; + + RegisterCachedBlock(block, blockIndex, handler, keepReference); + + _totalCachedBlockSizeBytes += blockSizeBytes; + } + + private void PrepareExtractors() + { + foreach (IRecordExtractor extractor in _recordExtractors) + { + lock (extractor) + { + extractor.Prepare(); + } + } + } + + public long GetBlockSizeBytes(int blockIndex, IComputationHandler handler) + { + if (!_activeBlocks.ContainsKey(blockIndex)) + { + return -1L; + } + + foreach (RecordBlock block in _activeBlocks[blockIndex]) + { + if (ReferenceEquals(block.Handler, handler)) + { + return block.EstimatedSizeBytes; + } + } + + return -1L; + } + + public bool IsBlockActive(int blockIndex) + { + return _activeBlocks.ContainsKey(blockIndex); + } + + public bool IsBlockActive(int blockIndex, 
IComputationHandler handler) + { + if (!_activeBlocks.ContainsKey(blockIndex)) + { + return false; + } + + foreach (RecordBlock block in _activeBlocks[blockIndex]) + { + if (ReferenceEquals(block.Handler, handler)) + { + return true; + } + } + + return false; + } + + private bool IsBlockCached(int blockIndex, IComputationHandler handler) + { + if (!_cachedBlocks.ContainsKey(blockIndex)) + { + return false; + } + + foreach (WeakRecordBlock block in _cachedBlocks[blockIndex]) + { + if (ReferenceEquals(block.Handler, handler)) + { + return true; + } + } + + return false; + } + + public void Dispose() + { + foreach (IRecordExtractor extractor in _recordExtractors) + { + extractor.Dispose(); + extractor.Reader?.Dispose(); + } + + _cacheProvider.Dispose(); + } + + public static IDataset[] SplitBlockwise(IDataset dataset, params int[] parts) + { + if (parts.Length == 0) + { + throw new ArgumentException("Parts cannot be an empty collection."); + } + + int splitInterval = parts.Sum(); + int lastEnd = 0; + IDataset[] slices = new IDataset[parts.Length]; + + for (int i = 0; i < parts.Length; i++) + { + slices[i] = new DatasetBlockwiseSlice(dataset, lastEnd, lastEnd + parts[i] - 1, splitInterval); + lastEnd += parts[i]; + } + + return slices; + } + + public static IDataset[] SplitRecordwise(IDataset dataset, params double[] parts) + { + if (parts.Length == 0) + { + throw new ArgumentException("Percentages cannot be an empty collection."); + } + + if (parts.Sum() > 1.0) + { + throw new ArgumentException($"Percentages sum cannot be > 1.0, but parts sum was {parts.Sum()}."); + } + + IDataset[] slices = new IDataset[parts.Length]; + + double lastOffset = 0.0; + + for (int i = 0; i < slices.Length; i++) + { + slices[i] = new DatasetRecordwiseSlice(dataset, lastOffset, parts[i]); + + lastOffset += parts[i]; + } + + return slices; + } + + internal abstract class RecordBlockBase + { + internal abstract Dictionary NamedBlockSections { get; set; } + internal abstract INDArray 
FirstNamedBlock { get; set; } + internal abstract bool Loaded { get; set; } + + internal IComputationHandler Handler; + internal bool Active; + internal int BlockIndex; + internal long NumberRecords; + internal long EstimatedSizeBytes; + } + + internal class RecordBlock : RecordBlockBase + { + internal sealed override Dictionary NamedBlockSections { get; set; } + internal sealed override INDArray FirstNamedBlock { get; set; } + internal override bool Loaded { get; set; } + + public RecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler) + { + NamedBlockSections = namedBlockSections; + BlockIndex = blockIndex; + NumberRecords = numberRecords; + EstimatedSizeBytes = estimatedSizeBytes; + Handler = handler; + + //record blocks internal block can be null + if (namedBlockSections != null) + { + FirstNamedBlock = namedBlockSections[namedBlockSections.First().Key]; + } + } + } + + internal class WeakRecordBlock : RecordBlockBase + { + internal override Dictionary NamedBlockSections + { + get + { + Dictionary target; + + return _namedBlockSections.TryGetTarget(out target) ? target : null; + } + set + { + _namedBlockSections.SetTarget(value); + } + } + + internal override INDArray FirstNamedBlock + { + get + { + INDArray target; + + return _firstNamedBlock.TryGetTarget(out target) ? 
target : null; + } + set + { + _firstNamedBlock.SetTarget(value); + } + } + + internal override bool Loaded + { + get + { + Dictionary target; + + return _namedBlockSections.TryGetTarget(out target); + } + set + { + } + } + + private readonly WeakReference> _namedBlockSections; + private readonly WeakReference _firstNamedBlock; + + public WeakRecordBlock(Dictionary namedBlockSections, int blockIndex, long numberRecords, long estimatedSizeBytes, IComputationHandler handler) + { + _namedBlockSections = new WeakReference>(namedBlockSections); + BlockIndex = blockIndex; + NumberRecords = numberRecords; + EstimatedSizeBytes = estimatedSizeBytes; + Handler = handler; + + //record blocks internal block can be null + if (namedBlockSections != null) + { + _firstNamedBlock = new WeakReference(namedBlockSections[namedBlockSections.First().Key]); + } + } + } + + public override string ToString() + { + return $"dataset \"{Name}\""; + } + } +} diff --git a/Sigma.Core/Data/Datasets/IDataset.cs b/Sigma.Core/Data/Datasets/IDataset.cs index a4ffc1c2..558b9fbc 100644 --- a/Sigma.Core/Data/Datasets/IDataset.cs +++ b/Sigma.Core/Data/Datasets/IDataset.cs @@ -76,7 +76,7 @@ public interface IDataset : IDisposable /// /// The number of currently active and loaded record blocks, with different block formats counting as different blocks. /// - int ActiveIndividualBlockCount { get; } + int ActiveIndividualBlockRegionCount { get; } /// /// The number of currently active and loaded record blocks, with different block formats of the same region counting as one active block index. diff --git a/Sigma.Core/Data/Datasets/RawDataset.cs b/Sigma.Core/Data/Datasets/RawDataset.cs new file mode 100644 index 00000000..5ac3e388 --- /dev/null +++ b/Sigma.Core/Data/Datasets/RawDataset.cs @@ -0,0 +1,335 @@ +/* +MIT License + +Copyright(c) 2016-2017 Florian Cäsar, Michael Plainer + +For full license see LICENSE in the root directory of this project. 
+*/ + + +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Sigma.Core.Handlers; +using Sigma.Core.Handlers.Backends.SigmaDiff.NativeCpu; +using Sigma.Core.MathAbstract; +using Sigma.Core.Utils; + +namespace Sigma.Core.Data.Datasets +{ + /// + /// A raw in-system-memory dataset which can be manually filled with records. + /// + [Serializable] + public class RawDataset : IDataset + { + private readonly IComputationHandler _internalHandler; + + /// + /// The name and identifier of this dataset. + /// Dataset names should be globally unique and easily identifiable. + /// + public string Name { get; } + + /// + /// Indicate if this dataset is an online dataset (meaning new data might be added during runtime). + /// By default, this is assumed to be false, indicating a static dataset. + /// Note: Data iterators and may perform certain optimisations for static datasets, so set this to false if possible. + /// + public bool Online { get; set; } + + /// + /// The preferred per block size in records. + /// Note: Not every block must obey this request (e.g. the last block might very well be a different size). + /// + public int TargetBlockSizeRecords { get; } + + /// + /// The maximum number of concurrently active blocks. + /// + public int MaxConcurrentActiveBlocks { get; } + + /// + /// The maximum total concurrently active block size in bytes. + /// + public long MaxTotalActiveBlockSizeBytes { get; } + + /// + /// The total size of all currently active record blocks in system memory in bytes. + /// + public long TotalActiveBlockSizeBytes { get; } + + /// + /// The maximum number of blocks to keep in the cache (inactive blocks are written to a cache, typically on disk, to be reloaded later). 
+ /// + public int MaxBlocksInCache { get { throw new NotSupportedException(); } set { throw new NotSupportedException(); } } + + /// + /// The maximum number of bytes to keep in the cache (inactive blocks are written to a cache, typically on disk, to be reloaded later). + /// + public long MaxBytesInCache { get { throw new NotSupportedException(); } set { throw new NotSupportedException(); } } + + /// + /// The names for all sections present in this dataset (e.g. "inputs", "targets"). + /// + public string[] SectionNames { get; } + + /// + /// A set of currently active and loaded record block indices. + /// + public IReadOnlyCollection ActiveBlockIndices { get; } + + /// + /// The number of currently active and loaded record blocks, with different block formats counting as different blocks. + /// + public int ActiveIndividualBlockRegionCount => _rawData.Count; + + /// + /// The number of currently active and loaded record blocks, with different block formats of the same region counting as one active block index. + /// + public int ActiveBlockRegionCount => _rawData.Count & 0x1; // can only be 1 active block + + /// + /// The current working data that can be edited and is "flushed" to the public raw data with the next call. + /// + private readonly IDictionary _internalWorkingData; + + /// + /// The raw data of this dataset that is returned via the functions. + /// + private readonly IDictionary> _rawData; + + /// + /// Create a raw dataset with a certain name and an internal cpu handler with 32-bit float precision. + /// + /// The globally unique name of this dataset. + public RawDataset(string name) : this(name, new CpuFloat32Handler()) + { + } + + /// + /// Create a raw dataset with a certain name and computation handler. + /// + /// The globally unique name of this dataset. + /// The internal handler to use for data management. 
+ public RawDataset(string name, IComputationHandler internalHandler) + { + if (name == null) throw new ArgumentNullException(nameof(name)); + if (internalHandler == null) throw new ArgumentNullException(nameof(internalHandler)); + + Name = name; + _internalHandler = internalHandler; + + _internalWorkingData = new ConcurrentDictionary(); + _rawData = new ConcurrentDictionary>(); + } + + /// + /// Add a record to a certain block. + /// Note: Feature shape is length of record, no time dimension, auto batch dimension. + /// + /// The record data type (must be primitive data type). + /// The block name (e.g. "inputs"). + /// The record. + public void AddRecord(string blockName, params T[] record) + { + AddShapedRecords(blockName, new long[] { record.Length }, new[] { record }); + } + + /// + /// Add a record with a certain feature shape to a certain block. + /// Note: Feature shape is as specified, no time dimension, auto batch dimension. + /// + /// The record data type (must be primitive data type). + /// The block name (e.g. "inputs"). + /// The feature shape. + /// The record. + public void AddRecord(string blockName, long[] featureShape, params T[] record) + { + AddShapedRecords(blockName, featureShape, record); + } + + /// + /// Add records to a certain block. + /// Note: Feature shape is length of record, no time dimension, auto batch dimension. + /// + /// The record data type (must be primitive data type). + /// The block name (e.g. "inputs"). + /// The records. + public void AddRecords(string blockName, params T[][] records) + { + AddShapedRecords(blockName, new long[] { records[0].Length }, records); + } + + /// + /// Add records with a certain feature shape to a certain block. + /// Note: Feature shape is as specified, no time dimension, auto batch dimension. + /// + /// The record data type (must be primitive data type). + /// The block name (e.g. "inputs"). + /// The feature shape. + /// The records. 
+ public void AddShapedRecords(string blockName, long[] featureShape, params T[][] records) + { + if (records.Length == 0) + { + return; + } + + long featureLength = ArrayUtils.Product(featureShape); + long[] newShape = ArrayUtils.Concatenate(new long[] { records.Length, 1 }, featureShape); // BatchTimeFeatures shape order, time dimension is not supported at the moment + long[] insertedShape = (long[])newShape.Clone(); + bool previousBlockExists = _internalWorkingData.ContainsKey(blockName); + + if (previousBlockExists) + { + newShape[0] += _internalWorkingData[blockName].Shape[0]; // append new record to end + } + + INDArray newBlock = _internalHandler.NDArray(newShape); + + long[] destinationBegin = new long[newBlock.Rank]; + + if (previousBlockExists) + { + INDArray oldBlock = _internalWorkingData[blockName]; + + for (int i = 1; i < oldBlock.Shape.Length; i++) + { + if (newShape[i] != oldBlock.Shape[i]) + { + throw new InvalidOperationException($"Shape mismatch: already existing block for \"{blockName}\" has shape {ArrayUtils.ToString(oldBlock.Shape)} but new block has shape {ArrayUtils.ToString(newShape)}"); + } + } + + long[] previousSourceBegin = new long[oldBlock.Rank]; + long[] previousSourceEnd = oldBlock.Shape.Select(i => i - 1).ToArray(); + + _internalHandler.Fill(oldBlock, newBlock, previousSourceBegin, previousSourceEnd, previousSourceBegin, previousSourceEnd); + + destinationBegin[0] = oldBlock.Shape[0]; + } + + long[] destinationEnd = insertedShape.Select(i => i - 1).ToArray(); + destinationEnd[0] = destinationBegin[0]; + + for (int i = 0; i < records.Length; i++) + { + _internalHandler.Fill(records[i], newBlock, destinationBegin, destinationEnd); + + destinationBegin[0]++; + destinationEnd[0]++; + } + + if (previousBlockExists) + { + _internalWorkingData[blockName] = newBlock; + } + else + { + _internalWorkingData.Add(blockName, newBlock); + } + } + + /// + public IDictionary FetchBlock(int blockIndex, IComputationHandler handler, bool 
shouldWaitUntilAvailable = true) + { + if (blockIndex != 0 || _internalWorkingData.Count == 0) + { + return null; // there is only 1 block in this raw dataset implementation (so if there's no data, there's no block) + } + + if (IsBlockActive(blockIndex, handler)) + { + return _rawData[handler]; + } + + if (!handler.CanConvert(_internalWorkingData.Values.First(), _internalHandler)) + { + return null; + } + + IDictionary convertedBlock = new Dictionary(); + + foreach (string blockName in _internalWorkingData.Keys) + { + convertedBlock[blockName] = handler.Convert(_internalWorkingData[blockName], _internalHandler); + } + + _rawData.Add(handler, convertedBlock); + + return convertedBlock; + } + + /// + public async Task> FetchBlockAsync(int blockIndex, IComputationHandler handler, bool shouldWaitUntilAvailable = true) + { + return await Task.Run(() => FetchBlock(blockIndex, handler, shouldWaitUntilAvailable)); + } + + /// + public void FreeBlock(int blockIndex, IComputationHandler handler) + { + if (IsBlockActive(blockIndex, handler)) + { + _rawData.Remove(handler); + } + } + + /// + public bool IsBlockActive(int blockIndex) + { + return blockIndex == 0; + } + + /// + public bool IsBlockActive(int blockIndex, IComputationHandler handler) + { + return blockIndex == 0 && _rawData.ContainsKey(handler); + } + + /// + public long GetBlockSizeBytes(int blockIndex, IComputationHandler handler) + { + if (!IsBlockActive(blockIndex, handler)) + { + return -1L; + } + + // TODO + throw new NotImplementedException(); + } + + /// + public bool CanFetchBlocksAfter(int blockIndex) + { + return blockIndex == -1; + } + + /// + public bool TrySetBlockSize(int blockSizeRecords) + { + throw new NotSupportedException(); + } + + /// + public IDataset[] SplitBlockwise(params int[] parts) + { + return ExtractedDataset.SplitBlockwise(this, parts); + } + + /// + public IDataset[] SplitRecordwise(params double[] percentages) + { + return ExtractedDataset.SplitRecordwise(this, percentages); + } + 
+ /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + public void Dispose() + { + } + } +} diff --git a/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs b/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs index f6b57adb..de71a009 100644 --- a/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs +++ b/Sigma.Core/Data/Preprocessors/Adaptive/BaseAdaptivePreprocessor.cs @@ -28,12 +28,22 @@ public abstract class BaseAdaptivePreprocessor : BasePreprocessor private readonly TPreprocessor _underlyingPreprocessor; private bool _initialAdaptionComplete; + /// + /// Create a base adaptive preprocessor with a certain underlying preprocessor (that will be adapted). + /// + /// The underlying preprocessor. + /// The section names to process in this preprocessor (all if null or empty). + protected BaseAdaptivePreprocessor(TPreprocessor underlyingPreprocessor, params string[] sectionNames) : this(underlyingPreprocessor, AdaptionRate.Initial, sectionNames) + { + } + /// /// Create a base adaptive preprocessor with a certain underlying preprocessor (that will be adapted). /// /// The underlying preprocessor. /// The adaption rate. - protected BaseAdaptivePreprocessor(TPreprocessor underlyingPreprocessor, AdaptionRate adaptionRate = AdaptionRate.Every) + /// The section names to process in this preprocessor (all if null or empty). 
+ protected BaseAdaptivePreprocessor(TPreprocessor underlyingPreprocessor, AdaptionRate adaptionRate, params string[] sectionNames) : base(sectionNames) { if (underlyingPreprocessor == null) throw new ArgumentNullException(nameof(underlyingPreprocessor)); if (!Enum.IsDefined(typeof(AdaptionRate), adaptionRate)) throw new InvalidEnumArgumentException(nameof(adaptionRate), (int) adaptionRate, typeof(AdaptionRate)); diff --git a/Sigma.Core/Data/Preprocessors/ShufflePreprocessor.cs b/Sigma.Core/Data/Preprocessors/ShufflePreprocessor.cs new file mode 100644 index 00000000..b2f84fbf --- /dev/null +++ b/Sigma.Core/Data/Preprocessors/ShufflePreprocessor.cs @@ -0,0 +1,75 @@ +using System; +using ManagedCuda.VectorTypes; +using Sigma.Core.Handlers; +using Sigma.Core.MathAbstract; + +namespace Sigma.Core.Data.Preprocessors +{ + [Serializable] + public class ShufflePreprocessor : BasePreprocessor + { + /// + /// The dimension along which should be shuffled. + /// Note: Must be >= 0. + /// + public int AlongDimension + { + get { return _alongDimension; } + set + { + if (value < 0) throw new ArgumentException($"Along dimension must be >= 0 but given value was {value}."); + + _alongDimension = value; + } + } + + /// + public override bool AffectsDataShape => false; + + private Random _random; + private int _alongDimension; + + /// + /// Create a shuffle preprocessor and optionally specify the dominant dimension (along which should be shuffled, batch dimension (0) by default). + /// + /// + public ShufflePreprocessor(int alongDimension = 0) + { + AlongDimension = alongDimension; + } + + /// + /// Process a certain ndarray with a certain computation handler. + /// + /// The ndarray to process. + /// The computation handler to do the processing with. + /// An ndarray with the processed contents of the given array (can be the same or a new one). 
+ internal override INDArray ProcessDirect(INDArray array, IComputationHandler handler) + { + int recordLength = (int) (array.Length / array.Shape[0]); + long[] firstBufferIndices = new long[array.Shape.Length]; + long[] secondBufferIndices = new long[array.Shape.Length]; + + _random = new Random(31415926); // fixed rng for reproducibility + + for (int i = 0; i < array.Shape[0]; i++) + { + int swapIndex = _random.Next((int) array.Shape[0]); + + for (int y = 0; y < recordLength; y++) + { + NDArrayUtils.GetIndices(recordLength * i + y, array.Shape, array.Strides, firstBufferIndices); + NDArrayUtils.GetIndices(recordLength * swapIndex + y, array.Shape, array.Strides, secondBufferIndices); + + double firstValue = array.GetValue(firstBufferIndices); + double secondValue = array.GetValue(secondBufferIndices); + + array.SetValue(secondValue, firstBufferIndices); + array.SetValue(firstValue, secondBufferIndices); + } + } + + return array; + } + } +} diff --git a/Sigma.Core/Dependencies/DiffSharp.dll b/Sigma.Core/Dependencies/DiffSharp.dll index ee9e31cf..0e868b5c 100644 Binary files a/Sigma.Core/Dependencies/DiffSharp.dll and b/Sigma.Core/Dependencies/DiffSharp.dll differ diff --git a/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs b/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs index 8453f69c..b2dc1ed3 100644 --- a/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs +++ b/Sigma.Core/Handlers/Backends/Debugging/DebugHandler.cs @@ -76,8 +76,8 @@ public DebugHandler(IComputationHandler underlyingHandler, bool throwExceptionOn // kind of ugly but saves me from writing more solid property handling ThrowExceptionOnReport = throwExceptionOnReport; Enabled = enabled; - CheckNaN = false; - CheckInfinite = false; + CheckNaN = enabled; + CheckInfinite = enabled; } private void Report(string message, params object[] values) @@ -99,23 +99,23 @@ private INDArray CheckNice(INDArray array, string paramName = "unspecified") if (array == null) { - Report($"ndarray 
{paramName} is null."); + Report($"ndarray \"{paramName}\" is null."); } else { if (array.Rank != array.Shape.Length) { - Report($"ndarray {paramName} has inconsistent rank ({array.Rank}) / shape (length {array.Length}).", array); + Report($"ndarray \"{paramName}\" has inconsistent rank ({array.Rank}) / shape (length {array.Length}).", array); } - if (CheckNaN && !UnderlyingHandler.IsNaN(array)) + if (CheckNaN && UnderlyingHandler.IsNaN(array)) { - Report($"ndarray {paramName} contains NaN values.", array); + Report($"ndarray \"{paramName}\" contains NaN values.", array); } - if (CheckInfinite && !UnderlyingHandler.IsNotFinite(array)) + if (CheckInfinite && UnderlyingHandler.IsNotFinite(array)) { - Report($"ndarray {paramName} contains infinite values.", array); + Report($"ndarray \"{paramName}\" contains infinite values.", array); } } @@ -135,12 +135,12 @@ private INumber CheckNice(INumber number, string paramName = "unspecified") } else { - if (CheckNaN && !UnderlyingHandler.IsNaN(number)) + if (CheckNaN && UnderlyingHandler.IsNaN(number)) { Report($"number {paramName} is a NaN value.", number); } - if (CheckInfinite && !UnderlyingHandler.IsNotFinite(number)) + if (CheckInfinite && UnderlyingHandler.IsNotFinite(number)) { Report($"number {paramName} is an infinite value.", number); } @@ -278,11 +278,25 @@ public void Fill(INDArray filler, INDArray arrayToFill) public void Fill(TOther value, INDArray arrayToFill) { - UnderlyingHandler.Fill(value, CheckNice(arrayToFill)); + UnderlyingHandler.Fill(value, arrayToFill); CheckNice(arrayToFill); } + public void Fill(INDArray filler, INDArray arrayToFill, long[] sourceBeginIndices, long[] sourceEndIndices, long[] destinationBeginIndices, long[] destinationEndIndices) + { + UnderlyingHandler.Fill(CheckNice(filler), CheckNice(arrayToFill), sourceBeginIndices, sourceEndIndices, destinationBeginIndices, destinationEndIndices); + + CheckNice(arrayToFill); + } + + public void Fill(T[] filler, INDArray arrayToFill, long[] 
destinationBeginIndices, long[] destinationEndIndices) + { + UnderlyingHandler.Fill(filler, CheckNice(arrayToFill), destinationBeginIndices, destinationEndIndices); + + CheckNice(arrayToFill); + } + public INDArray FlattenTime(INDArray array) { if (Enabled && array.Rank < 2) // two or three? technically 2 is enough ([BT]F) but 3 makes more sense diff --git a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs index d91d8eb8..4a9e14ab 100644 --- a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs +++ b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpBackendHandle.cs @@ -53,9 +53,7 @@ internal DiffSharpBackendHandle(IBlasBackend blasBackend, ILapackBackend lapackB public abstract FSharpOption> Solve_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b); public abstract FSharpOption> SolveSymmetric_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b); public abstract ISigmaDiffDataBuffer Diagonal_M(ShapedDataBufferView a); - public abstract ISigmaDiffDataBuffer Map_F_V(FSharpFunc a, ISigmaDiffDataBuffer b); - public abstract ISigmaDiffDataBuffer Map2_F_V_V(FSharpFunc> a, ISigmaDiffDataBuffer b, ISigmaDiffDataBuffer obj2); - public abstract ISigmaDiffDataBuffer ReshapeCopy_MRows_V(ShapedDataBufferView value); + public abstract ISigmaDiffDataBuffer ReshapeCopy_MRows_V(ShapedDataBufferView value); public abstract ShapedDataBufferView Mul_Out_V_V(ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b); public abstract ShapedDataBufferView Add_M_M(ShapedDataBufferView a, ShapedDataBufferView b); public abstract ShapedDataBufferView Add_S_M(T a, ShapedDataBufferView b); @@ -70,10 +68,14 @@ internal DiffSharpBackendHandle(IBlasBackend blasBackend, ILapackBackend lapackB public abstract FSharpOption> Inverse_M(ShapedDataBufferView a); public abstract FSharpOption Det_M(ShapedDataBufferView a); public abstract ShapedDataBufferView Transpose_M(ShapedDataBufferView a); - public abstract 
ShapedDataBufferView Map_F_M(FSharpFunc a, ShapedDataBufferView b); - public abstract ShapedDataBufferView Map2_F_M_M(FSharpFunc> a, ShapedDataBufferView b, ShapedDataBufferView obj2); public abstract ShapedDataBufferView ReshapeCopy_V_MRows(int rows, ISigmaDiffDataBuffer value); public abstract ShapedDataBufferView RepeatReshapeCopy_V_MRows(int rows, ISigmaDiffDataBuffer value); public abstract ShapedDataBufferView RepeatReshapeCopy_V_MCols(int cols, ISigmaDiffDataBuffer value); + public abstract ISigmaDiffDataBuffer Map_F_V(MapOp mapOp, FSharpFunc function, ISigmaDiffDataBuffer value); + public abstract ISigmaDiffDataBuffer Map_F_S_V(T other, MapOp mapOp, FSharpFunc function, ISigmaDiffDataBuffer value); + public abstract ISigmaDiffDataBuffer Map2_F_V_V(MapOp mapOp, FSharpFunc> function, ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b); + public abstract ShapedDataBufferView Map_F_M(MapOp mapOp, FSharpFunc function, ShapedDataBufferView value); + public abstract ShapedDataBufferView Map_F_S_M(T other, MapOp mapOp, FSharpFunc function, ShapedDataBufferView value); + public abstract ShapedDataBufferView Map2_F_M_M(MapOp mapOp, FSharpFunc> function, ShapedDataBufferView a, ShapedDataBufferView b); } } diff --git a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32BackendHandle.cs b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32BackendHandle.cs index 43d023af..952a7102 100644 --- a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32BackendHandle.cs +++ b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32BackendHandle.cs @@ -7,9 +7,8 @@ For full license see LICENSE in the root directory of this project. 
*/ using System; -using System.Linq; +using DiffSharp.Backend; using Microsoft.FSharp.Core; -using Sigma.Core.Utils; using static DiffSharp.Util; namespace Sigma.Core.Handlers.Backends.SigmaDiff @@ -22,18 +21,20 @@ public unsafe class DiffSharpFloat32BackendHandle : DiffSharpBackendHandle /// Create a DiffSharpFloat32BackendHandle with a certain BLAS and LAPACK backend and an associated handle tag. /// - /// - /// - /// + /// The BLAS backend to use (must use 32-bit floats). + /// The LAPACK backend to use (must use 32-bit floats). + /// The backend tag to use. public DiffSharpFloat32BackendHandle(IBlasBackend blasBackend, ILapackBackend lapackBackend, long backendTag) : base(blasBackend, lapackBackend, backendTag) { } + /// public override ISigmaDiffDataBuffer CreateDataBuffer(float[] values) { return new SigmaDiffDataBuffer(values, backendTag: BackendTag); } + /// public override float L1Norm_V(ISigmaDiffDataBuffer value) { if (value.Length == 0) @@ -50,6 +51,7 @@ public override float L1Norm_V(ISigmaDiffDataBuffer value) } } + /// public override float L2Norm_V(ISigmaDiffDataBuffer value) { if (value.Length == 0) @@ -66,6 +68,7 @@ public override float L2Norm_V(ISigmaDiffDataBuffer value) } } + /// public override float SupNorm_V(ISigmaDiffDataBuffer value) { if (value.Length == 0) @@ -84,6 +87,7 @@ public override float SupNorm_V(ISigmaDiffDataBuffer value) } } + /// public override float Sum_V(ISigmaDiffDataBuffer value) { if (value.Length == 0) @@ -102,11 +106,13 @@ public override float Sum_V(ISigmaDiffDataBuffer value) return sum; } + /// public override float Sum_M(ISigmaDiffDataBuffer value) { return Sum_V(value); } + /// public override ISigmaDiffDataBuffer Add_V_V(ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b) { if (a.Length == 0) @@ -132,6 +138,7 @@ public override ISigmaDiffDataBuffer Add_V_V(ISigmaDiffDataBuffer return b; } + /// public override ISigmaDiffDataBuffer Add_S_V(float a, ISigmaDiffDataBuffer b) { if (b.Length == 0) @@ -152,6 +159,7 @@ 
public override ISigmaDiffDataBuffer Add_S_V(float a, ISigmaDiffDataBuffe return b; } + /// public override ISigmaDiffDataBuffer Sub_V_V(ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b) { if (a.Length == 0) @@ -177,6 +185,7 @@ public override ISigmaDiffDataBuffer Sub_V_V(ISigmaDiffDataBuffer return b; } + /// public override ISigmaDiffDataBuffer Sub_S_V(float a, ISigmaDiffDataBuffer b) { if (b.Length == 0) @@ -197,6 +206,7 @@ public override ISigmaDiffDataBuffer Sub_S_V(float a, ISigmaDiffDataBuffe return b; } + /// public override ISigmaDiffDataBuffer Sub_V_S(ISigmaDiffDataBuffer a, float b) { if (a.Length == 0) @@ -217,6 +227,7 @@ public override ISigmaDiffDataBuffer Sub_V_S(ISigmaDiffDataBuffer return a; } + /// public override ISigmaDiffDataBuffer Mul_S_V(float a, ISigmaDiffDataBuffer b) { if (b.Length == 0) @@ -236,6 +247,7 @@ public override ISigmaDiffDataBuffer Mul_S_V(float a, ISigmaDiffDataBuffe return b; } + /// public override ISigmaDiffDataBuffer Mul_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b) { if (a.Length * b.Length == 0) @@ -260,16 +272,19 @@ public override ISigmaDiffDataBuffer Mul_M_V(ShapedDataBufferView return z; } + /// public override ISigmaDiffDataBuffer Mul_M_V_Add_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b, ISigmaDiffDataBuffer obj2) { throw new NotImplementedException(); } + /// public override float Mul_Dot_V_V(ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer n) { throw new NotImplementedException(); } + /// public override ISigmaDiffDataBuffer Mul_V_M(ISigmaDiffDataBuffer a, ShapedDataBufferView b) { if (a.Length * b.Length == 0) @@ -294,22 +309,26 @@ public override ISigmaDiffDataBuffer Mul_V_M(ISigmaDiffDataBuffer return z; } + /// public override FSharpOption> Solve_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b) { throw new NotImplementedException(); } + /// public override FSharpOption> SolveSymmetric_M_V(ShapedDataBufferView a, ISigmaDiffDataBuffer b) { throw new NotImplementedException(); } + /// public override 
ISigmaDiffDataBuffer Diagonal_M(ShapedDataBufferView a) { throw new NotImplementedException(); } - public override ISigmaDiffDataBuffer Map_F_V(FSharpFunc a, ISigmaDiffDataBuffer b) + /// + public override ISigmaDiffDataBuffer Map_F_V(MapOp mapOp, FSharpFunc a, ISigmaDiffDataBuffer b) { if (b.Length == 0) { @@ -327,27 +346,35 @@ public override ISigmaDiffDataBuffer Map_F_V(FSharpFunc a, return b; } - public override ISigmaDiffDataBuffer Map2_F_V_V(FSharpFunc> f, ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b) + /// + public override ISigmaDiffDataBuffer Map_F_S_V(float other, MapOp mapOp, FSharpFunc function, ISigmaDiffDataBuffer value) + { + return Map_F_V(mapOp, function, value); + } + + /// + public override ISigmaDiffDataBuffer Map2_F_V_V(MapOp mapOp, FSharpFunc> function, ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b) { if (a.Length == 0) { - return Map2_F_V_V(f, CreateDataBuffer(new float[b.Length]), b); + return Map2_F_V_V(mapOp, function, CreateDataBuffer(new float[b.Length]), b); } if (b.Length == 0) { - return Map2_F_V_V(f, a, CreateDataBuffer(new float[a.Length])); + return Map2_F_V_V(mapOp, function, a, CreateDataBuffer(new float[a.Length])); } b = b.DeepCopy(); for (int i = 0; i < a.Length; i++) { - b.Data[i] = f.Invoke(a.Data[i + a.Offset]).Invoke(b.Data[i + b.Offset]); + b.Data[i] = function.Invoke(a.Data[i + a.Offset]).Invoke(b.Data[i + b.Offset]); } return b; } + /// public override ShapedDataBufferView Mul_Out_V_V(ISigmaDiffDataBuffer a, ISigmaDiffDataBuffer b) { if (a.Length * b.Length == 0) @@ -372,6 +399,7 @@ public override ShapedDataBufferView Mul_Out_V_V(ISigmaDiffDataBuffer(z, m, n); } + /// public override ShapedDataBufferView Add_M_M(ShapedDataBufferView a, ShapedDataBufferView b) { if (a.Length == 0) @@ -397,6 +425,7 @@ public override ShapedDataBufferView Add_M_M(ShapedDataBufferView return b; } + /// public override ShapedDataBufferView Add_S_M(float a, ShapedDataBufferView b) { if (b.Length == 0) @@ -417,11 +446,13 @@ public override 
ShapedDataBufferView Add_S_M(float a, ShapedDataBufferVie return b; } + /// public override ShapedDataBufferView Add_V_MCols(ISigmaDiffDataBuffer a, ShapedDataBufferView b) { throw new NotImplementedException(); } + /// public override ShapedDataBufferView Sub_M_M(ShapedDataBufferView a, ShapedDataBufferView b) { if (a.Length == 0) @@ -447,6 +478,7 @@ public override ShapedDataBufferView Sub_M_M(ShapedDataBufferView return a; } + /// public override ShapedDataBufferView Sub_M_S(ShapedDataBufferView a, float b) { if (a.Length == 0) @@ -467,6 +499,7 @@ public override ShapedDataBufferView Sub_M_S(ShapedDataBufferView return a; } + /// public override ShapedDataBufferView Sub_S_M(float a, ShapedDataBufferView b) { if (b.Length == 0) @@ -499,6 +532,7 @@ public override ShapedDataBufferView Sub_S_M(float a, ShapedDataBufferVie return b; } + /// public override ShapedDataBufferView Mul_M_M(ShapedDataBufferView a, ShapedDataBufferView b) { if (a.Length * b.Length == 0) @@ -522,6 +556,7 @@ public override ShapedDataBufferView Mul_M_M(ShapedDataBufferView return new ShapedDataBufferView(z, a.Rows, b.Cols); } + /// public override ShapedDataBufferView Mul_S_M(float a, ShapedDataBufferView b) { if (b.Length == 0) @@ -541,11 +576,13 @@ public override ShapedDataBufferView Mul_S_M(float a, ShapedDataBufferVie return b; } + /// public override ShapedDataBufferView Mul_M_M_Add_V_MCols(ShapedDataBufferView a, ShapedDataBufferView b, ISigmaDiffDataBuffer c) { throw new NotImplementedException(); } + /// public override ShapedDataBufferView Mul_Had_M_M(ShapedDataBufferView a, ShapedDataBufferView b) { if (a.Length == 0) @@ -571,11 +608,13 @@ public override ShapedDataBufferView Mul_Had_M_M(ShapedDataBufferView public override FSharpOption> Inverse_M(ShapedDataBufferView a) { throw new NotImplementedException(); } + /// public override FSharpOption Det_M(ShapedDataBufferView a) { if (a.Length == 0) @@ -611,6 +650,7 @@ public override FSharpOption Det_M(ShapedDataBufferView a) return 
FSharpOption.Some(det); } + /// public override ShapedDataBufferView Transpose_M(ShapedDataBufferView a) { if (a.Length == 0) @@ -628,13 +668,64 @@ public override ShapedDataBufferView Transpose_M(ShapedDataBufferView Map_F_M(FSharpFunc f, ShapedDataBufferView a) + private bool _InternalOptimisedMapOp_F_M(MapOp mapOp, ref ShapedDataBufferView a) + { + if (mapOp.IsExp) + { + a = a.DeepCopy(); + int upper = a.DataBuffer.Offset + a.DataBuffer.Length; + for (int i = a.DataBuffer.Offset; i < upper; i++) + { + a.DataBuffer.Data[i] = (float) Math.Exp(a.DataBuffer.Data[i]); + } + + return true; + } + else if (mapOp.IsSqrt) + { + a = a.DeepCopy(); + int upper = a.DataBuffer.Offset + a.DataBuffer.Length; + for (int i = a.DataBuffer.Offset; i < upper; i++) + { + a.DataBuffer.Data[i] = (float) Math.Sqrt(a.DataBuffer.Data[i]); + } + + return true; + } + + return false; + } + + private bool _InternalOptimisedMapOp_F_S_M(float other, MapOp mapOp, ref ShapedDataBufferView a) + { + if (mapOp.IsDiv) + { + a = a.DeepCopy(); + int upper = a.DataBuffer.Offset + a.DataBuffer.Length; + for (int i = a.DataBuffer.Offset; i < upper; i++) + { + a.DataBuffer.Data[i] = other / a.DataBuffer.Data[i]; + } + + return true; + } + + return false; + } + + /// + public override ShapedDataBufferView Map_F_M(MapOp mapOp, FSharpFunc f, ShapedDataBufferView a) { if (a.Length == 0) { return new ShapedDataBufferView(CreateDataBuffer(new float[0]), 0L, 0L); } + if (_InternalOptimisedMapOp_F_M(mapOp, ref a)) + { + return a; + } + a = a.DeepCopy(); int upper = a.DataBuffer.Offset + a.DataBuffer.Length; @@ -646,7 +737,19 @@ public override ShapedDataBufferView Map_F_M(FSharpFunc f, return a; } - public override ShapedDataBufferView Map2_F_M_M(FSharpFunc> f, ShapedDataBufferView a, ShapedDataBufferView b) + /// + public override ShapedDataBufferView Map_F_S_M(float other, MapOp mapOp, FSharpFunc function, ShapedDataBufferView value) + { + if (_InternalOptimisedMapOp_F_S_M(other, mapOp, ref value)) + { + return 
value; + } + + return Map_F_M(mapOp, function, value); + } + + /// + public override ShapedDataBufferView Map2_F_M_M(MapOp mapOp, FSharpFunc> f, ShapedDataBufferView a, ShapedDataBufferView b) { if (a.Length == 0) { @@ -654,7 +757,6 @@ public override ShapedDataBufferView Map2_F_M_M(FSharpFunc Map2_F_M_M(FSharpFunc public override ISigmaDiffDataBuffer ReshapeCopy_MRows_V(ShapedDataBufferView value) { if (value.Length == 0) @@ -673,6 +776,7 @@ public override ISigmaDiffDataBuffer ReshapeCopy_MRows_V(ShapedDataBuffer return value.DataBuffer.DeepCopy(); } + /// public override ShapedDataBufferView ReshapeCopy_V_MRows(int rows, ISigmaDiffDataBuffer value) { if (value.Length == 0) @@ -685,11 +789,13 @@ public override ShapedDataBufferView ReshapeCopy_V_MRows(int rows, ISigma return new ShapedDataBufferView(value.DeepCopy(), rows, n); } + /// public override ShapedDataBufferView RepeatReshapeCopy_V_MRows(int rows, ISigmaDiffDataBuffer value) { throw new NotImplementedException(); } + /// public override ShapedDataBufferView RepeatReshapeCopy_V_MCols(int cols, ISigmaDiffDataBuffer value) { throw new NotImplementedException(); diff --git a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32Handler.cs b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32Handler.cs index 5f168dc0..90263f63 100644 --- a/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32Handler.cs +++ b/Sigma.Core/Handlers/Backends/SigmaDiff/DiffSharpFloat32Handler.cs @@ -35,9 +35,9 @@ public abstract class DiffSharpFloat32Handler : IComputationHandler, ISerialisat internal DiffSharpBackendHandle DiffsharpBackendHandle { - get { return _diffsharpBackendHandle; } - private set { _diffsharpBackendHandle = value; } - } + get { return _diffsharpBackendHandle; } + private set { _diffsharpBackendHandle = value; } + } [NonSerialized] private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); @@ -117,22 +117,39 @@ protected ADNDFloat32Array 
AssignTag(ADNDFloat32Array array) return array; } - // IComputationHandler stuff that is probably different for each diffsharp handler implementation + // IComputationHandler stuff that is probably different for each diffsharp handler implementation + /// public abstract void InitAfterDeserialisation(INDArray array); + /// public abstract long GetSizeBytes(params INDArray[] array); + /// public abstract bool IsInterchangeable(IComputationHandler otherHandler); + /// public abstract INDArray NDArray(params long[] shape); + /// public abstract INDArray NDArray(TOther[] values, params long[] shape); + /// public abstract INumber Number(object value); + /// public abstract IDataBuffer DataBuffer(T[] values); + /// public abstract INDArray AsNDArray(INumber number); + /// public abstract INumber AsNumber(INDArray array, params long[] indices); + /// public abstract bool CanConvert(INDArray array, IComputationHandler otherHandler); + /// public abstract INDArray Convert(INDArray array, IComputationHandler otherHandler); + /// public abstract void Fill(INDArray filler, INDArray arrayToFill); - public abstract void Fill(TOther value, INDArray arrayToFill); - - protected ADNDFloat32Array ConvertInternal(INDArray array) + /// + public abstract void Fill(TOther value, INDArray arrayToFill); + /// + public abstract void Fill(INDArray filler, INDArray arrayToFill, long[] sourceBeginIndices, long[] sourceEndIndices, long[] destinationBeginIndices, long[] destinationEndIndices); + /// + public abstract void Fill(T[] filler, INDArray arrayToFill, long[] destinationBeginIndices, long[] destinationEndIndices); + + protected ADNDFloat32Array ConvertInternal(INDArray array) { return new ADNDFloat32Array(_backendTag, array.GetDataAs(), array.Shape); } diff --git a/Sigma.Core/Handlers/Backends/SigmaDiff/NativeCpu/CPUFloat32Handler.cs b/Sigma.Core/Handlers/Backends/SigmaDiff/NativeCpu/CPUFloat32Handler.cs index d73c6f7e..fc5871ca 100644 --- 
a/Sigma.Core/Handlers/Backends/SigmaDiff/NativeCpu/CPUFloat32Handler.cs +++ b/Sigma.Core/Handlers/Backends/SigmaDiff/NativeCpu/CPUFloat32Handler.cs @@ -12,6 +12,7 @@ For full license see LICENSE in the root directory of this project. using System; using Sigma.Core.MathAbstract.Backends.SigmaDiff; using Sigma.Core.MathAbstract.Backends.SigmaDiff.NativeCpu; +using Sigma.Core.Utils; namespace Sigma.Core.Handlers.Backends.SigmaDiff.NativeCpu { @@ -21,22 +22,27 @@ namespace Sigma.Core.Handlers.Backends.SigmaDiff.NativeCpu [Serializable] public class CpuFloat32Handler : DiffSharpFloat32Handler { + /// public CpuFloat32Handler() : base(new OpenBlasBlasBackend(), new OpenBlasLapackBackend()) { } + /// public override IDataType DataType => DataTypes.Float32; + /// public override IDataBuffer DataBuffer(T[] values) { return new SigmaDiffDataBuffer(values, backendTag: DiffsharpBackendHandle.BackendTag); } + /// public override INDArray NDArray(params long[] shape) { return AssignTag(new ADNDFloat32Array(DiffsharpBackendHandle.BackendTag, shape)).SetAssociatedHandler(this); } + /// public override INDArray NDArray(TOther[] values, params long[] shape) { float[] convertedValues = new float[values.Length]; @@ -50,11 +56,13 @@ public override INDArray NDArray(TOther[] values, params long[] shape) return AssignTag(new ADNDFloat32Array(DiffsharpBackendHandle.BackendTag, convertedValues, shape)).SetAssociatedHandler(this); } + /// public override INumber Number(object value) { return new ADFloat32Number((float) System.Convert.ChangeType(value, typeof(float))).SetAssociatedHandler(this); } + /// public override INDArray AsNDArray(INumber number) { ADFloat32Number internalNumber = InternaliseNumber(number); @@ -62,6 +70,7 @@ public override INDArray AsNDArray(INumber number) return AssignTag(new ADNDFloat32Array(DNDArray.OfDNumber(internalNumber._adNumberHandle, DiffsharpBackendHandle))); } + /// public override INumber AsNumber(INDArray array, params long[] indices) { 
ADNDFloat32Array internalArray = InternaliseArray(array); @@ -70,12 +79,14 @@ public override INumber AsNumber(INDArray array, params long[] indices) return new ADFloat32Number(DNDArray.ToDNumber(internalArray._adArrayHandle, (int) flatIndex)); } + /// public override void InitAfterDeserialisation(INDArray array) { // nothing to do here for this handler, all relevant components are serialised automatically, // diffsharp does not need to be de-serialised, components only need to be removed from trace } + /// public override long GetSizeBytes(params INDArray[] arrays) { long totalSizeBytes = 0L; @@ -93,34 +104,39 @@ public override long GetSizeBytes(params INDArray[] arrays) return totalSizeBytes; } + /// public override bool IsInterchangeable(IComputationHandler otherHandler) { //there are no interchangeable implementations so it will have to be the same type return otherHandler.GetType() == GetType(); } + /// public override bool CanConvert(INDArray array, IComputationHandler otherHandler) { //if it's the same base unit and at least the same precision we can convert return otherHandler.DataType.BaseUnderlyingType == DataType.BaseUnderlyingType && otherHandler.DataType.SizeBytes >= DataType.SizeBytes; } + /// public override INDArray Convert(INDArray array, IComputationHandler otherHandler) { return ConvertInternal(array); } + /// public override void Fill(INDArray filler, INDArray arrayToFill) { - IDataBuffer arrayToFillData = ((ADNDArray) arrayToFill).Data; - IDataBuffer fillerData = ((ADNDArray) filler).Data; + IDataBuffer arrayToFillData = InternaliseArray(arrayToFill).Data; + IDataBuffer fillerData = InternaliseArray(filler).Data; arrayToFillData.SetValues(fillerData.Data, fillerData.Offset, arrayToFillData.Offset, Math.Min(arrayToFill.Length, filler.Length)); } - public override void Fill(TOther value, INDArray arrayToFill) + /// + public override void Fill(TOther value, INDArray arrayToFill) { - IDataBuffer arrayToFillData = ((ADNDArray) arrayToFill).Data; + 
IDataBuffer arrayToFillData = InternaliseArray(arrayToFill).Data; float floatValue = (float) System.Convert.ChangeType(value, typeof(float)); @@ -129,5 +145,36 @@ public override void Fill(TOther value, INDArray arrayToFill) arrayToFillData.Data.SetValue(floatValue, i); } } - } + + /// + public override void Fill(INDArray filler, INDArray arrayToFill, long[] sourceBeginIndices, long[] sourceEndIndices, long[] destinationBeginIndices, long[] destinationEndIndices) + { + IDataBuffer fillerData = InternaliseArray(filler).Data; + IDataBuffer arrayToFillData = InternaliseArray(arrayToFill).Data; + + int sourceOffset = (int) NDArrayUtils.GetFlatIndex(filler.Shape, filler.Strides, sourceBeginIndices); + int sourceLength = (int) NDArrayUtils.GetFlatIndex(filler.Shape, filler.Strides, sourceEndIndices) - sourceOffset + 1; // +1 because end is inclusive + int destinationOffset = (int) NDArrayUtils.GetFlatIndex(arrayToFill.Shape, arrayToFill.Strides, destinationBeginIndices); + int destinationLength = (int) NDArrayUtils.GetFlatIndex(arrayToFill.Shape, arrayToFill.Strides, destinationEndIndices) - destinationOffset + 1; // same here + + if (sourceLength < 0) throw new ArgumentOutOfRangeException($"Source begin indices must be smaller than source end indices, but source length was {sourceLength}."); + if (destinationLength < 0) throw new ArgumentOutOfRangeException($"Destination begin indices must be smaller than destination end indices, but destination length was {destinationLength}."); + if (sourceLength != destinationLength) throw new ArgumentException($"Source and destination indices length must batch, but source length was {sourceLength} and destination legnth was {destinationLength}."); + + Array.Copy(fillerData.Data, sourceOffset, arrayToFillData.Data, destinationOffset, sourceLength); + } + + /// + public override void Fill(T[] filler, INDArray arrayToFill, long[] destinationBeginIndices, long[] destinationEndIndices) + { + IDataBuffer arrayToFillData = 
InternaliseArray(arrayToFill).Data; + + int destinationOffset = (int) NDArrayUtils.GetFlatIndex(arrayToFill.Shape, arrayToFill.Strides, destinationBeginIndices); + int destinationLength = (int) NDArrayUtils.GetFlatIndex(arrayToFill.Shape, arrayToFill.Strides, destinationEndIndices) - destinationOffset + 1; // +1 because end is inclusive + + if (destinationLength < 0) throw new ArgumentOutOfRangeException($"Destination begin indices must be smaller than destination end indices, but destination length was {destinationLength}."); + + Array.Copy(filler, 0, arrayToFillData.Data, destinationOffset, destinationLength); + } + } } diff --git a/Sigma.Core/Handlers/IComputationHandler.cs b/Sigma.Core/Handlers/IComputationHandler.cs index 5c752a21..bcea2dca 100644 --- a/Sigma.Core/Handlers/IComputationHandler.cs +++ b/Sigma.Core/Handlers/IComputationHandler.cs @@ -133,6 +133,28 @@ public interface IComputationHandler /// The ndarray to fill. void Fill(INDArray filler, INDArray arrayToFill); + /// + /// Fill an ndarray with the contents of another ndarray within a specific range. + /// Note: The index ranges must be of the same size (in source and destination). + /// + /// The filler ndarray (from which the values will be copied in the specified range). + /// The array to fill within the specified range. + /// The begin indices in the filler array. + /// The end indices in the filler array. + /// The begin indices in the array to fill. + /// The end indices in the array to fill. + void Fill(INDArray filler, INDArray arrayToFill, long[] sourceBeginIndices, long[] sourceEndIndices, long[] destinationBeginIndices, long[] destinationEndIndices); + + /// + /// Fill an ndarray with the contents of another ndarray within a specific range. + /// Note: The index ranges must be of the same size (in source and destination). + /// + /// The filler ndarray (from which the values will be copied in the specified range). + /// The array to fill within the specified range. 
+ /// The begin indices in the array to fill. + /// The end indices in the array to fill. + void Fill(T[] filler, INDArray arrayToFill, long[] destinationBeginIndices, long[] destinationEndIndices); + /// /// Fill an ndarray with a single value. /// This is not a traceable operation. diff --git a/Sigma.Core/Layers/BaseLayer.cs b/Sigma.Core/Layers/BaseLayer.cs index 22f0bcd0..58171b1b 100644 --- a/Sigma.Core/Layers/BaseLayer.cs +++ b/Sigma.Core/Layers/BaseLayer.cs @@ -15,6 +15,7 @@ namespace Sigma.Core.Layers /// /// A basic base layer to simplify custom layer implementations of the ILayer interface. /// + [Serializable] public abstract class BaseLayer : ILayer { public string Name { get; } diff --git a/Sigma.Core/Layers/Cost/BaseCostLayer.cs b/Sigma.Core/Layers/Cost/BaseCostLayer.cs index 96506071..c1ed7ab3 100644 --- a/Sigma.Core/Layers/Cost/BaseCostLayer.cs +++ b/Sigma.Core/Layers/Cost/BaseCostLayer.cs @@ -6,6 +6,7 @@ MIT License For full license see LICENSE in the root directory of this project. */ +using System; using Sigma.Core.Architecture; using Sigma.Core.Handlers; using Sigma.Core.MathAbstract; @@ -16,6 +17,7 @@ namespace Sigma.Core.Layers.Cost /// /// A base cost layer that takes of getting the predictions and targets sorted out to calculate the cost more easily. 
/// + [Serializable] public abstract class BaseCostLayer : BaseLayer { /// diff --git a/Sigma.Core/Layers/Cost/SoftMaxCrossEntropyCostLayer.cs b/Sigma.Core/Layers/Cost/SoftMaxCrossEntropyCostLayer.cs index 7687f41e..4d872fcf 100644 --- a/Sigma.Core/Layers/Cost/SoftMaxCrossEntropyCostLayer.cs +++ b/Sigma.Core/Layers/Cost/SoftMaxCrossEntropyCostLayer.cs @@ -36,7 +36,7 @@ protected override INumber CalculateCost(INDArray predictions, INDArray targets, INDArray a = handler.Multiply(targets, logPredictions); INDArray inverseTargets = handler.Subtract(1, targets); - INDArray inversePredictions = handler.Subtract(1, predictions); + INDArray inversePredictions = handler.Subtract(1 + 1e-6, predictions); INDArray b = handler.Multiply(inverseTargets, handler.Log(inversePredictions)); INumber cost = handler.Divide(handler.Sum(handler.Add(a, b)), -predictions.Shape[0]); diff --git a/Sigma.Core/MathAbstract/Backends/SigmaDiff/ADNDArray.cs b/Sigma.Core/MathAbstract/Backends/SigmaDiff/ADNDArray.cs index bea8e5cc..ddb17646 100644 --- a/Sigma.Core/MathAbstract/Backends/SigmaDiff/ADNDArray.cs +++ b/Sigma.Core/MathAbstract/Backends/SigmaDiff/ADNDArray.cs @@ -117,7 +117,8 @@ public ADNDArray(params long[] shape) Data = new DataBuffer(Length); } - public virtual object DeepCopy() + /// + public virtual object DeepCopy() { return new ADNDArray((IDataBuffer) Data.DeepCopy(), (long[]) Shape.Clone()).SetAssociatedHandler(AssociatedHandler); } diff --git a/Sigma.Core/MathAbstract/Backends/SigmaDiff/ArrayPool.cs b/Sigma.Core/MathAbstract/Backends/SigmaDiff/ArrayPool.cs new file mode 100644 index 00000000..3c31a8f3 --- /dev/null +++ b/Sigma.Core/MathAbstract/Backends/SigmaDiff/ArrayPool.cs @@ -0,0 +1,72 @@ +/* +MIT License + +Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer + +For full license see LICENSE in the root directory of this project. 
+*/ + +using System.Collections.Generic; +using Sigma.Core.Utils; + +namespace Sigma.Core.MathAbstract.Backends.SigmaDiff +{ + /// + /// An array pool for pooling arrays. + /// Duh. + /// + /// + public class ArrayPool + { + private readonly IDictionary> _availableArrays; + + /// + /// Create a new array pool. + /// + public ArrayPool() + { + _availableArrays = new Dictionary>(); + } + + /// + /// Allocate an array of a certain size from this array pool. + /// + /// The array size. + /// An array of the given size. + public T[] Allocate(int arraySize) + { + if (!_availableArrays.ContainsKey(arraySize)) + { + return new T[arraySize]; + } + + IList pooledArrays = _availableArrays[arraySize]; + int lastIndex = pooledArrays.Count - 1; + + T[] lastPooledArray = pooledArrays[lastIndex]; + + pooledArrays.RemoveAt(lastIndex); + + return lastPooledArray; + } + + /// + /// Free a specific array allocated with this + /// + /// + public void Free(T[] array) + { + // TODO what happens if the same array is freed multiple times? check list? but that's slower than checking a set... but a set doesn't have an item order... + _availableArrays.TryGetValue(array.Length, () => new List()).Add(array); + } + + /// + /// Free all pooled arrays. + /// + public void FreeAll() + { + _availableArrays.Clear(); + // now get to work GC + } + } +} diff --git a/Sigma.Core/Monitors/Synchronisation/ISynchronisationHandler.cs b/Sigma.Core/Monitors/Synchronisation/ISynchronisationHandler.cs index 76b37fed..92e6ba1c 100644 --- a/Sigma.Core/Monitors/Synchronisation/ISynchronisationHandler.cs +++ b/Sigma.Core/Monitors/Synchronisation/ISynchronisationHandler.cs @@ -21,6 +21,26 @@ public interface ISynchronisationHandler /// SigmaEnvironment Sigma { get; } + /// + /// Add an additional source that will be used if the values cannot be found. + /// + /// The source that will be added. + void AddSynchronisationSource(ISynchronisationSource source); + + /// + /// Remove an additional source. 
+ /// + /// The source that will be removed. + /// True if the source could be removed, false otherwise. + bool RemoveSynchronisationSource(ISynchronisationSource source); + + /// + /// Check if a source is contained. + /// + /// The source that will be checked. + /// True, if the handler contains the source - false otherwise + bool ContainsSynchronisationSoruce(ISynchronisationSource source); + /// /// Indicate that a value has changed and synchronise it with the given . /// @@ -41,5 +61,15 @@ public interface ISynchronisationHandler /// The type of the value that will be gathered. /// The fully resolved identifier for the parameter that will be received. T SynchroniseGet(IRegistry registry, string key); + + /// + /// Update a value with a given action if it has changed (not ). + /// + /// The type of the value that will be gathered. + /// The registry in which the entry will be set. + /// The fully resolved identifier for the parameter that will be received. + /// The current value of the object. + /// The method that will be called if the parameter has to be updated. + void SynchroniseUpdate(IRegistry registry, string key, T currentVal,Action update); } } \ No newline at end of file diff --git a/Sigma.Core/Monitors/Synchronisation/ISynchronisationSource.cs b/Sigma.Core/Monitors/Synchronisation/ISynchronisationSource.cs new file mode 100644 index 00000000..f6f17e9b --- /dev/null +++ b/Sigma.Core/Monitors/Synchronisation/ISynchronisationSource.cs @@ -0,0 +1,42 @@ +using System.Collections.Generic; + +namespace Sigma.Core.Monitors.Synchronisation +{ + /// + /// A synchronisation source provides additional data for a synchronisation handler if a value cannot be found in the default registry. + /// + public interface ISynchronisationSource + { + /// + /// Try to retrieve a value from this source (if existent). + /// + /// The type of the value that will be retrieved. + /// The key of the value. 
+ /// The value itself that will be assigned if it could be retrieved, null otherwise. + /// True if the source could retrieve given key, false otherwise. + bool TryGet(string key, out T val); + + /// + /// Try to set a value from this source (if existent). + /// + /// The type of the value that will be set. + /// The key of the value. + /// The value itself that will be assigned if it is applicable. + /// True if the source could set given key, false otherwise. + bool TrySet(string key, T val); + + /// + /// Determine whether a given key is contained / managed by this source. + /// + /// The key that will be checked. + /// True if given key can be accessed with get / set, false otherwise. + bool Contains(string key); + + /// + /// This is a list of keys this source provides. It is completely optional, although it is recommended to implement it. + /// + /// Once a new source is added, the keys of the sources are checked against to determine double entries which makes debugging for users easier (as log entries are produced automatically). + /// + string[] Keys { get; } + } +} \ No newline at end of file diff --git a/Sigma.Core/Monitors/Synchronisation/SetValueCommand.cs b/Sigma.Core/Monitors/Synchronisation/SetValueCommand.cs index 5e59f5dc..cbad4a4d 100644 --- a/Sigma.Core/Monitors/Synchronisation/SetValueCommand.cs +++ b/Sigma.Core/Monitors/Synchronisation/SetValueCommand.cs @@ -110,6 +110,7 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) for (int i = 0; i < keys.Length; i++) { + //TODO: validate if successfully set and call error otherwise (for each key?) 
resolver.ResolveSet(keys[i], values[i], AddItentifierIfNotExists, typeof(T)); } } diff --git a/Sigma.Core/Monitors/Synchronisation/SynchronisationHandler.cs b/Sigma.Core/Monitors/Synchronisation/SynchronisationHandler.cs index 7da376a8..86c44038 100644 --- a/Sigma.Core/Monitors/Synchronisation/SynchronisationHandler.cs +++ b/Sigma.Core/Monitors/Synchronisation/SynchronisationHandler.cs @@ -8,6 +8,8 @@ For full license see LICENSE in the root directory of this project. using System; using System.Collections.Generic; +using System.Linq; +using log4net; using Sigma.Core.Training.Operators; using Sigma.Core.Utils; @@ -17,18 +19,25 @@ namespace Sigma.Core.Monitors.Synchronisation /// The default synchronisation handler for Sigma. It is responsible for syncing values between monitors /// and the environment itself. /// - public class SynchronisationHandler : ISynchronisationHandler + public class SynchronisationHandler : ISynchronisationHandler, ISynchronisationSource { /// /// The environment this handler is associated with. /// public SigmaEnvironment Sigma { get; } + private readonly ILog _logger = LogManager.GetLogger(typeof(SigmaEnvironment)); + /// /// Map every registry to a resolver for that registry. /// protected Dictionary RegistryResolvers { get; } + /// + /// The assigned synchronisation sources. + /// + protected List Sources; + /// /// Default constructor for . 
/// @@ -39,10 +48,54 @@ public SynchronisationHandler(SigmaEnvironment sigma) if (sigma == null) throw new ArgumentNullException(nameof(sigma)); RegistryResolvers = new Dictionary(); + Sources = new List(); Sigma = sigma; } + /// + public void AddSynchronisationSource(ISynchronisationSource source) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + + if (source.Keys != null) + { + foreach (string key in source.Keys) + { + foreach (ISynchronisationSource savedSource in Sources) + { + if (savedSource.Keys != null) + { + if (savedSource.Keys.Contains(key)) + { + _logger.Warn($"The key {key} is added to a synchronisation handler by {source.GetType()} but is already provided by {savedSource.GetType()}. It is uncertain which key will be taken (possible performance decrease)."); + } + } + } + } + } + + Sources.Add(source); + } + + /// + public bool RemoveSynchronisationSource(ISynchronisationSource source) + { + if (source == null) throw new ArgumentNullException(nameof(source)); + return Sources.Remove(source); + } + + + /// + /// Check if a source is contained. + /// + /// The source that will be checked. 
+ /// True, if the handler contains the source - false otherwise + public bool ContainsSynchronisationSoruce(ISynchronisationSource source) + { + return Sources.Contains(source); + } + /// public virtual void SynchroniseSet(IRegistry registry, string key, T val, Action onSuccess = null, Action onError = null) { @@ -52,6 +105,7 @@ public virtual void SynchroniseSet(IRegistry registry, string key, T val, Act if (ReferenceEquals(op.Registry, registry)) { //TODO: test if callback is called + //TODO: on error check sources for other to set the value op.InvokeCommand(new SetValueCommand(key, val, () => onSuccess?.Invoke(val))); return; @@ -74,8 +128,71 @@ public virtual void SynchroniseSet(IRegistry registry, string key, T val, Act /// public virtual T SynchroniseGet(IRegistry registry, string key) { - IRegistryResolver resolver = RegistryResolvers.TryGetValue(registry, () => new RegistryResolver(registry)); - return resolver.ResolveGetSingleWithDefault(key, default(T)); + if (registry != null) + { + IRegistryResolver resolver = RegistryResolvers.TryGetValue(registry, () => new RegistryResolver(registry)); + //return resolver.ResolveGetSingle<>() + string[] emptyArrayThrowaway; + + T[] result = resolver.ResolveGet(key, out emptyArrayThrowaway); + + if (result.Length != 0) + { + return result[0]; + } + } + + foreach (ISynchronisationSource source in Sources) + { + T res; + if (source.TryGet(key, out res)) + { + return res; + } + } + + return default(T); + } + + /// + /// Update a value with a given action if it has changed (). + /// + /// The type of the value that will be gathered. + /// The registry in which the entry will be set. + /// The fully resolved identifier for the parameter that will be received. + /// The current value of the object. + /// The method that will be called if the parameter has to be updated. 
+ public void SynchroniseUpdate(IRegistry registry, string key, T currentVal, Action update) + { + if (update == null) throw new ArgumentNullException(nameof(update)); + + T newObj = SynchroniseGet(registry, key); + if (newObj != null && currentVal == null || newObj != null && !newObj.Equals(currentVal)) + { + update(newObj); + } } + + public string[] Keys { get; } + + /// + bool ISynchronisationSource.TryGet(string key, out T val) + { + throw new NotImplementedException("Get currently not implemented as no registry is passed"); + } + + + /// + bool ISynchronisationSource.Contains(string key) + { + throw new NotImplementedException("Get currently not implemented as no registry is passed"); + } + + /// + bool ISynchronisationSource.TrySet(string key, T val) + { + throw new NotImplementedException("Set currently not implemented as there is no error"); + } + } } \ No newline at end of file diff --git a/Sigma.Core/Parameterisation/ParameterisationManager.cs b/Sigma.Core/Parameterisation/ParameterisationManager.cs index 9f267cad..ef7c16e6 100644 --- a/Sigma.Core/Parameterisation/ParameterisationManager.cs +++ b/Sigma.Core/Parameterisation/ParameterisationManager.cs @@ -16,6 +16,7 @@ namespace Sigma.Core.Parameterisation /// /// A parameterisation /// + [Serializable] public class ParameterisationManager : IParameterisationManager { private readonly IDictionary _identifierMappings; diff --git a/Sigma.Core/Persistence/Serialisation.cs b/Sigma.Core/Persistence/Serialisation.cs index bf8ebe59..432c9430 100644 --- a/Sigma.Core/Persistence/Serialisation.cs +++ b/Sigma.Core/Persistence/Serialisation.cs @@ -14,6 +14,7 @@ For full license see LICENSE in the root directory of this project. 
using System.Runtime.Serialization; using log4net; using log4net.Core; +using Sigma.Core.Parameterisation; using Sigma.Core.Utils; namespace Sigma.Core.Persistence @@ -36,7 +37,7 @@ public static class Serialisation /// Optionally indicate where the log messages should written to (verbose = Info, otherwise Debug). public static void WriteBinaryFile(object obj, string filename, bool verbose = true) { - Write(obj, Target.FileByName(filename), Serialisers.BinarySerialiser, verbose); + Write(obj, Target.FileByName(filename), Serialisers.BinarySerialiser, verbose: verbose); } /// @@ -63,7 +64,8 @@ public static T ReadBinaryFile(string filename, bool verbose = true) /// The serialiser. /// Optionally indicate if the stream should be automatically closed. /// Optionally indicate where the log messages should written to (verbose = Info, otherwise Debug). - public static void Write(object obj, Stream target, ISerialiser serialiser, bool autoClose = true, bool verbose = true) + /// The number of bytes written (if exposed by the used target stream). + public static long Write(object obj, Stream target, ISerialiser serialiser, bool autoClose = true, bool verbose = true) { if (obj == null) throw new ArgumentNullException(nameof(obj)); if (target == null) throw new ArgumentNullException(nameof(target)); @@ -86,6 +88,8 @@ public static void Write(object obj, Stream target, ISerialiser serialiser, bool LoggingUtils.Log(verbose ? 
Level.Info : Level.Debug, $"Done writing {obj.GetType().Name} {obj} to target stream {target} using serialiser {serialiser}, " + $"wrote {(bytesWritten / 1024.0):#.#}kB, took {stopwatch.ElapsedMilliseconds}ms.", ClazzLogger); + + return bytesWritten; } /// @@ -118,7 +122,7 @@ public static T Read(Stream target, ISerialiser serialiser, bool verbose = tr // automatically restore all logger instances if (field.FieldType == typeof(ILog)) { - field.SetValue(parent, LogManager.GetLogger(parent.GetType())); + field.SetValue(parent, LogManager.GetLogger(Assembly.GetCallingAssembly(), parent.GetType().Namespace + "." + parent.GetType().Name)); } (obj as ISerialisationNotifier)?.OnDeserialised(); @@ -130,6 +134,38 @@ public static T Read(Stream target, ISerialiser serialiser, bool verbose = tr return (T) read; } + /// + /// Attempt to read and validate certain object from a binary file, return the original value if unsuccessful. + /// + /// The object type. + /// The file name. + /// The original value. + /// Optionally indicate where the log messages should written to (verbose = Info, otherwise Debug). + /// The optional validation function to validate the read object with (if false, the original value is returned). + /// The read (i.e. existing) if successfully read and validated, otherwise the original value. + public static T ReadBinaryFileIfExists(string fileName, T originalValue, bool verbose = true, Func validationFunction = null) + { + try + { + T existing = ReadBinaryFile(fileName, verbose); + + if (validationFunction == null || validationFunction.Invoke(existing)) + { + LoggingUtils.Log(verbose ? Level.Info : Level.Debug, $"Read and validation of type {typeof(T)} successful, returning existing value.", ClazzLogger); + + return existing; + } + + LoggingUtils.Log(verbose ? Level.Info : Level.Debug, $"Read of type {typeof(T)} successful, validation failed, returning default value.", ClazzLogger); + } + catch (Exception e) + { + LoggingUtils.Log(verbose ? 
Level.Info : Level.Debug, $"Read of type {typeof(T)} failed with {e}, returning default value.", ClazzLogger); + } + + return originalValue; + } + /// /// Traverse the object graph of a given object (i.e. all related and referenced objects, recursively). /// @@ -139,13 +175,14 @@ public static T Read(Stream target, ISerialiser serialiser, bool verbose = tr internal static void TraverseObjectGraph(object root, ISet traversedObjects, Action action) { Type type = root.GetType(); + traversedObjects.Add(root); // traverse all types up to object base type for all relevant fields in the graph do { FieldInfo[] fields = type.GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); - + // hierarchy change listeners.. // for every type check all fields for relevance foreach (FieldInfo field in fields) { @@ -158,11 +195,11 @@ internal static void TraverseObjectGraph(object root, ISet traversedObje if (value != null) { - // TODO smarter optimisation than checking for system namespace, maybe build cache with type information? + // Note: I am completely aware how awful this "optimisation" is, but it works and there currently is no time to implement a better system. 
string ns = value.GetType().Namespace; - bool boringSystemType = ns.StartsWith("System") && !ns.StartsWith("System.Collections"); + bool boringType = ns.StartsWith("System") && !ns.StartsWith("System.Collections") || ns.StartsWith("log4net"); - if (!boringSystemType && !traversedObjects.Contains(value)) + if (!boringType && !traversedObjects.Contains(value) && !Attribute.IsDefined(field, typeof(NonSerializedAttribute))) { TraverseObjectGraph(value, traversedObjects, action); diff --git a/Sigma.Core/Sigma.Core.csproj b/Sigma.Core/Sigma.Core.csproj index 704ebe09..ecd5b8d2 100644 --- a/Sigma.Core/Sigma.Core.csproj +++ b/Sigma.Core/Sigma.Core.csproj @@ -138,10 +138,11 @@ - + + @@ -160,6 +161,7 @@ + @@ -195,6 +197,7 @@ + @@ -207,6 +210,7 @@ + @@ -231,6 +235,7 @@ + @@ -242,6 +247,8 @@ + + @@ -289,6 +296,7 @@ + @@ -313,7 +321,6 @@ - diff --git a/Sigma.Core/Sigma.cs b/Sigma.Core/Sigma.cs index 211c771a..9a39c0a4 100644 --- a/Sigma.Core/Sigma.cs +++ b/Sigma.Core/Sigma.cs @@ -59,7 +59,7 @@ public class SigmaEnvironment : ISerialisationNotifier private ManualResetEvent _processQueueEvent; [NonSerialized] - private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); + private readonly ILog _logger = LogManager.GetLogger(typeof(SigmaEnvironment)); /// /// The unique name of this environment. @@ -86,6 +86,7 @@ public class SigmaEnvironment : ISerialisationNotifier /// public IParameterisationManager ParameterisationManager { get; } + //TODO: put into trainer /// /// The handler that is responsible for syncing the monitor with operators / workers. /// @@ -329,6 +330,16 @@ public async Task RunAsync() await Task.Run(() => Run()); } + /// + /// Prepare and run this sigma environment. + /// Note: This should only be used with smaller projects, as an early call gives monitors more time to setup and more instant user feedback. 
+ /// + public void PrepareAndRun() + { + Prepare(); + Run(); + } + private void InitialiseTrainers() { foreach (ITrainer trainer in _trainersByName.Values) diff --git a/Sigma.Core/Training/Hooks/BaseHook.cs b/Sigma.Core/Training/Hooks/BaseHook.cs index 70958af3..253e4aeb 100644 --- a/Sigma.Core/Training/Hooks/BaseHook.cs +++ b/Sigma.Core/Training/Hooks/BaseHook.cs @@ -57,7 +57,7 @@ public abstract class BaseHook : IHook /// Set this to true if performance intensive operations (e.g. storing to disk, processing large arrays) are used in this hook. /// Note: When invoked in background, hooks received a complete copy of all required registry entries and can therefore not directly modify the parameters of a worker/operator. /// - public bool InvokeInBackground { get; protected set; } = false; + public bool InvokeInBackground { get; protected set; } /// /// The operator that owns this hook and dispatched it for execution. @@ -114,6 +114,17 @@ protected BaseHook(ITimeStep timestep, ISet requiredRegistryEntries) ParameterRegistry = new Registry(); } + /// + /// Set this hook to be invoked in a background thread. + /// + /// + public BaseHook SetInvokeInBackground() + { + InvokeInBackground = true; + + return this; + } + /// /// Invoke this hook only when a certain hook invoke criteria is satisfied. /// diff --git a/Sigma.Core/Training/Hooks/HookInvokeCriteria.cs b/Sigma.Core/Training/Hooks/HookInvokeCriteria.cs index 7f4a5f39..2914febd 100644 --- a/Sigma.Core/Training/Hooks/HookInvokeCriteria.cs +++ b/Sigma.Core/Training/Hooks/HookInvokeCriteria.cs @@ -306,6 +306,7 @@ public override string ToString() /// /// A threshold criteria that fires when a certain threshold is reached (once or continuously as specified). /// + [Serializable] public class ThresholdCriteria : HookInvokeCriteria { /// @@ -369,6 +370,7 @@ public override string ToString() /// /// An extrema criteria that fires when a value has reached a new extrema (min / max). 
/// + [Serializable] public class ExtremaCriteria : HookInvokeCriteria { /// diff --git a/Sigma.Core/Training/Hooks/LambdaHook.cs b/Sigma.Core/Training/Hooks/LambdaHook.cs index 13ab9d78..55991209 100644 --- a/Sigma.Core/Training/Hooks/LambdaHook.cs +++ b/Sigma.Core/Training/Hooks/LambdaHook.cs @@ -35,7 +35,7 @@ public LambdaHook(ITimeStep timestep, Action invok /// A helper resolver for complex registry entries (automatically cached). public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) { - var action = ParameterRegistry.Get>("invoke_action"); + Action action = ParameterRegistry.Get>("invoke_action"); action.Invoke(registry, resolver); } } diff --git a/Sigma.Core/Training/Hooks/Processors/MetricProcessorHook.cs b/Sigma.Core/Training/Hooks/Processors/MetricProcessorHook.cs new file mode 100644 index 00000000..c24f3f91 --- /dev/null +++ b/Sigma.Core/Training/Hooks/Processors/MetricProcessorHook.cs @@ -0,0 +1,85 @@ +/* +MIT License + +Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer + +For full license see LICENSE in the root directory of this project. 
+*/ + +using System; +using System.Collections.Generic; +using Sigma.Core.Handlers; +using Sigma.Core.MathAbstract; +using Sigma.Core.Utils; + +namespace Sigma.Core.Training.Hooks.Processors +{ + [Serializable] + public class MetricProcessorHook : BaseHook where T : class + { + public MetricProcessorHook(string registryEntryToProcess, Func metricFunction, string metricSharedResultEntry) : this(Utils.TimeStep.Every(1, TimeScale.Iteration), registryEntryToProcess, metricFunction, metricSharedResultEntry) + { + } + + public MetricProcessorHook(ITimeStep timestep, string registryEntryToProcess, Func metricFunction, string metricSharedResultIdentifier) : base(timestep, registryEntryToProcess) + { + if (registryEntryToProcess == null) throw new ArgumentNullException(nameof(registryEntryToProcess)); + if (metricFunction == null) throw new ArgumentNullException(nameof(metricFunction)); + if (metricSharedResultIdentifier == null) throw new ArgumentNullException(nameof(metricSharedResultIdentifier)); + + InvokePriority = -1000; + ParameterRegistry["registry_entry_to_process"] = registryEntryToProcess; + ParameterRegistry["metric_function"] = metricFunction; + ParameterRegistry["metric_shared_result_identifier"] = metricSharedResultIdentifier; + } + + /// + public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) + { + IComputationHandler handler = Operator.Handler; + + string registryEntryToProcess = ParameterRegistry.Get("registry_entry_to_process"); + Func metricFunction = ParameterRegistry.Get>("metric_function"); + string metricSharedResultIdentifier = ParameterRegistry.Get("metric_shared_result_identifier"); + + object[] entries = resolver.ResolveGet(registryEntryToProcess); + + double totalMetric = 0.0; + int count = 0; + + foreach (object entry in entries) + { + T entryAsT = entry as T; + IEnumerable entryAsEnumerable = entry as IEnumerable; + IDictionary entryAsDictionary = entry as IDictionary; + + if (entryAsDictionary != null) + { + 
entryAsEnumerable = entryAsDictionary.Values; + } + + if (entryAsT != null) + { + totalMetric += metricFunction.Invoke(entryAsT, handler).GetValueAs(); + count++; + } + else if (entryAsEnumerable != null) + { + foreach (T value in entryAsEnumerable) + { + totalMetric += metricFunction.Invoke(value, handler).GetValueAs(); + count++; + } + } + else + { + throw new InvalidOperationException($"Cannot process metric for entry of type {entry.GetType()} with identifier \"{registryEntryToProcess}\", must be {typeof(T)} or enumerable thereof."); + } + } + + double resultMetric = totalMetric / count; + + resolver.ResolveSet(metricSharedResultIdentifier, resultMetric, addIdentifierIfNotExists: true); + } + } +} diff --git a/Sigma.Core/Training/Hooks/Processors/NumberAccumulatorHook.cs b/Sigma.Core/Training/Hooks/Processors/NumberAccumulatorHook.cs index dbe445c9..a33eafc1 100644 --- a/Sigma.Core/Training/Hooks/Processors/NumberAccumulatorHook.cs +++ b/Sigma.Core/Training/Hooks/Processors/NumberAccumulatorHook.cs @@ -17,16 +17,19 @@ namespace Sigma.Core.Training.Hooks.Accumulators [Serializable] public class NumberAccumulatorHook : BaseHook { - public NumberAccumulatorHook(string registryEntry, TimeStep timeStep, int resetEvery = -1, int resetInterval = 0) : this(registryEntry, registryEntry.Replace('.', '_') + "_accumulated", timeStep, resetEvery, resetInterval) + public NumberAccumulatorHook(string registryEntry, TimeStep timeStep, bool averageMode = false, int resetEvery = -1, int resetInterval = 0) : this(registryEntry, registryEntry.Replace('.', '_') + "_accumulated", timeStep, averageMode, resetEvery, resetInterval) { } - public NumberAccumulatorHook(string registryEntry, string resultEntry, TimeStep timeStep, int resetEvery = -1, int resetInterval = 0) : base(timeStep, registryEntry) + public NumberAccumulatorHook(string registryEntry, string resultEntry, TimeStep timeStep, bool averageMode = false, int resetEvery = -1, int resetInterval = 0) : base(timeStep, 
registryEntry) { ParameterRegistry["registry_entry"] = registryEntry; ParameterRegistry["shared_result_entry"] = resultEntry; + ParameterRegistry["accumulated_value"] = 0.0; ParameterRegistry["reset_interval"] = resetInterval; ParameterRegistry["reset_every"] = resetEvery; + ParameterRegistry["average_mode"] = averageMode; + ParameterRegistry["count_since_reset"] = 0; } /// @@ -40,18 +43,31 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) string resultEntry = ParameterRegistry.Get("shared_result_entry"); double value = resolver.ResolveGetSingle(registryEntry); - double accumulatedValue = resolver.ResolveGetSingleWithDefault(resultEntry, 0.0); + double previousAccumulatedValue = ParameterRegistry.Get("accumulated_value"); int currentInterval = HookUtils.GetCurrentInterval(registry, TimeStep.TimeScale); int resetInterval = ParameterRegistry.Get("reset_interval"); int resetEvery = ParameterRegistry.Get("reset_every"); + int countSinceReset = ParameterRegistry.Get("count_since_reset"); if (currentInterval == resetInterval || resetEvery > 0 && currentInterval % resetEvery == 0) { - accumulatedValue = 0.0; + previousAccumulatedValue = 0.0; + countSinceReset = 0; } - resolver.ResolveSet(resultEntry, value + accumulatedValue, addIdentifierIfNotExists: true); + countSinceReset++; + + double result = value + previousAccumulatedValue; + + if (ParameterRegistry.Get("average_mode")) + { + result /= countSinceReset; + } + + ParameterRegistry["count_since_reset"] = countSinceReset; + ParameterRegistry["accumulated_value"] = value + previousAccumulatedValue; + resolver.ResolveSet(resultEntry, result, addIdentifierIfNotExists: true); } } } diff --git a/Sigma.Core/Training/Hooks/Reporters/RunningTimeReporter.cs b/Sigma.Core/Training/Hooks/Reporters/RunningTimeReporter.cs index f25d8dd2..a77ec491 100644 --- a/Sigma.Core/Training/Hooks/Reporters/RunningTimeReporter.cs +++ b/Sigma.Core/Training/Hooks/Reporters/RunningTimeReporter.cs @@ -27,7 +27,7 @@ 
public class RunningTimeReporter : BaseHook /// /// The time step. /// The interval span to average over. - public RunningTimeReporter(TimeStep timeStep, int averageSpan = 4) : base(Utils.TimeStep.Every(1, timeStep.TimeScale)) + public RunningTimeReporter(ITimeStep timeStep, int averageSpan = 4) : base(timeStep) { DefaultTargetMode = TargetMode.Global; @@ -46,8 +46,8 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) { string baseResultKey = ParameterRegistry.Get("base_result_key"); - long lastTime = resolver.ResolveGetSingleWithDefault(baseResultKey + "_last", -1L); - long averageTime = resolver.ResolveGetSingleWithDefault(baseResultKey + "_average", -1L); + long lastTime = resolver.ResolveGetSingleWithDefault(baseResultKey + "_last", -1L); + long averageTime = resolver.ResolveGetSingleWithDefault(baseResultKey + "_average", -1L); Report(TimeStep.TimeScale, lastTime, averageTime); } diff --git a/Sigma.Core/Training/Hooks/Reporters/ValueReporterHook.cs b/Sigma.Core/Training/Hooks/Reporters/ValueReporterHook.cs index a532c516..d0a0eda5 100644 --- a/Sigma.Core/Training/Hooks/Reporters/ValueReporterHook.cs +++ b/Sigma.Core/Training/Hooks/Reporters/ValueReporterHook.cs @@ -29,7 +29,8 @@ public class ValueReporterHook : BaseHook /// /// The value that will be fetched (i.e. registry identifier). E.g. "optimiser.cost_total" /// The the hook will executed on. - public ValueReporterHook(string valueIdentifier, ITimeStep timestep) : this(new[] { valueIdentifier }, timestep) { } + /// Indicate whether or not to report the current epoch and iteration in addition to the values. + public ValueReporterHook(string valueIdentifier, ITimeStep timestep, bool reportEpochIteration = false) : this(new[] { valueIdentifier }, timestep, reportEpochIteration: reportEpochIteration) { } /// /// Create a hook that conditionally (extrema criteria) fetches a given value (i.e. registry identifier) at a given . 
@@ -55,12 +56,13 @@ public ValueReporterHook(string valueIdentifier, ITimeStep timestep, double thre On(new ThresholdCriteria(GetAccumulatedIdentifier(valueIdentifier), target, threshold, fireContinously)); } - /// - /// Create a hook that fetches a given amount of values (i.e. registry identifiers) at a given . - /// - /// The values that will be fetched (i.e. registry identifiers). E.g. "optimiser.cost_total", ... - /// The the hook will executed on. - public ValueReporterHook(string[] valueIdentifiers, ITimeStep timestep) : base(timestep, valueIdentifiers) + /// + /// Create a hook that fetches a given amount of values (i.e. registry identifiers) at a given . + /// + /// The values that will be fetched (i.e. registry identifiers). E.g. "optimiser.cost_total", ... + /// The the hook will executed on. + /// Indicate whether or not to report the current epoch and iteration in addition to the values. + public ValueReporterHook(string[] valueIdentifiers, ITimeStep timestep, bool averageValues = false, bool reportEpochIteration = false) : base(timestep, valueIdentifiers) { if (valueIdentifiers.Length == 0) throw new ArgumentException("Value identifiers cannot be empty (it's the whole point of this hook)."); @@ -90,7 +92,7 @@ public ValueReporterHook(string[] valueIdentifiers, ITimeStep timestep) : base(t resetInterval = 0; } - RequireHook(new NumberAccumulatorHook(value, accumulatedIdentifiers[i], Utils.TimeStep.Every(1, TimeScale.Iteration), resetEvery, resetInterval)); + RequireHook(new NumberAccumulatorHook(value, accumulatedIdentifiers[i], Utils.TimeStep.Every(1, TimeScale.Iteration), averageValues, resetEvery, resetInterval)); valueBuffer.Add(value, null); } @@ -98,6 +100,7 @@ public ValueReporterHook(string[] valueIdentifiers, ITimeStep timestep) : base(t ParameterRegistry["value_identifiers"] = valueIdentifiers; ParameterRegistry["accumulated_identifiers"] = accumulatedIdentifiers; ParameterRegistry["value_buffer"] = valueBuffer; + 
ParameterRegistry["report_epoch_iteration"] = reportEpochIteration; } private static string GetAccumulatedIdentifier(string value) @@ -125,7 +128,7 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) valuesByIdentifier[valueIdentifiers[i]] = value; } - ReportValues(valuesByIdentifier); + ReportValues(valuesByIdentifier, ParameterRegistry.Get("report_epoch_iteration"), registry.Get("epoch"), registry.Get("iteration")); } /// @@ -133,9 +136,14 @@ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) /// Note: By default, this method writes to the logger. If you want to report to anywhere else, overwrite this method. /// /// The values by their identifier. - protected virtual void ReportValues(IDictionary valuesByIdentifier) + /// A boolean indicating whether or not to report the current epoch / iteration. + /// The current epoch. + /// The current iteration. + protected virtual void ReportValues(IDictionary valuesByIdentifier, bool reportEpochIteration, int epoch, int iteration) { - _logger.Info(string.Join(", ", valuesByIdentifier.Select(pair => $"{pair.Key} = {pair.Value}"))); + string formattedValues = string.Join(", ", valuesByIdentifier.Select(pair => $"{pair.Key} = {pair.Value}")); + + _logger.Info((reportEpochIteration ? $"epoch {epoch} / iteration {iteration}: " : "") + formattedValues); } } } diff --git a/Sigma.Core/Training/Hooks/Reporters/ValueSourceReporterHook.cs b/Sigma.Core/Training/Hooks/Reporters/ValueSourceReporterHook.cs new file mode 100644 index 00000000..aef22a04 --- /dev/null +++ b/Sigma.Core/Training/Hooks/Reporters/ValueSourceReporterHook.cs @@ -0,0 +1,154 @@ +using System.Collections.Generic; +using Sigma.Core.Utils; +using Sigma.Core.Monitors.Synchronisation; + +namespace Sigma.Core.Training.Hooks.Reporters +{ + /// + /// A hook that stores given values and can provide them to a as a source. 
+ /// + public class ValueSourceReporterHook : BaseHook, ISynchronisationSource + { + private const string ValueIdentifier = "values"; + private const string RegistryResolver = "resolver"; + + //private readonly IDictionary _values = new Dictionary(); + + /// + /// Create a hook that fetches a given value (i.e. registry identifier) at a given . + /// + /// The value that will be fetched (i.e. registry identifier). E.g. "optimiser.cost_total" + /// The the hook will executed on. + public ValueSourceReporterHook(TimeStep timestep, string valueIdentifier) : base(timestep, valueIdentifier) + { + Initialise(valueIdentifier); + } + + + /// + /// Create a hook that fetches a given amount of values (i.e. registry identifiers) at a given . + /// + /// The values that will be fetched (i.e. registry identifiers). E.g. "optimiser.cost_total", ... + /// The the hook will executed on. + public ValueSourceReporterHook(ITimeStep timestep, params string[] valueIdentifiers) : base(timestep, valueIdentifiers) + { + Initialise(valueIdentifiers); + } + + private void Initialise() + { + IRegistry reg = new Registry(); + reg.Add(RegistryResolver, new RegistryResolver(reg)); + + ParameterRegistry.Add(ValueIdentifier, reg); + } + + /// + /// Initialise the dictionary containing the values with given . + /// + /// The value that will be fetched. + protected void Initialise(string valueIdentifier) + { + Initialise(); + IRegistry values = (IRegistry) ParameterRegistry[ValueIdentifier]; + values.Add(valueIdentifier, null); + Keys = new[] {valueIdentifier}; + } + + /// + /// Initialise the dictionary containing the values with given . + /// + /// The values that will be fetched. 
+ protected void Initialise(string[] valueIdentifiers) + { + Initialise(); + IRegistry values = (IRegistry) ParameterRegistry[ValueIdentifier]; + + foreach (string identifier in valueIdentifiers) + { + values.Add(identifier, null); + } + Keys = valueIdentifiers; + } + + + /// + /// Try to retrieve a value from this source (if existent). + /// + /// The type of the value that will be retrieved. + /// The key of the value. + /// The value itself that will be assigned if it could be retrieved, null otherwise. + /// True if the source could retrieve given key, false otherwise. + public bool TryGet(string key, out T val) + { + IRegistry values = (IRegistry) ParameterRegistry[ValueIdentifier]; + IRegistryResolver resolver = values.Get(RegistryResolver); + + //TODO: validate lock requirement, probably it is required + lock (values) + { + T[] vals = resolver.ResolveGet(key); + + if (vals.Length > 0) + { + val = vals[0]; + return true; + } + + val = default(T); + return false; + } + } + + /// + /// No set supported in this observative hook. + /// + /// The type of the value that will be set. + /// The key of the value. + /// The value itself that will be assigned if it applicable. + /// True if the source could set given key, false otherwise. + public bool TrySet(string key, T val) + { + // a set is not supported + return false; + } + + /// + /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied. + /// + /// The registry containing the required values for this hook's execution. + /// A helper resolver for complex registry entries (automatically cached). 
+ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) + { + IDictionary values = (IDictionary) ParameterRegistry[ValueIdentifier]; + + //TODO: validate lock requirement, probably it is required + lock (values) + { + foreach (KeyValuePair valuePair in registry) + { + values[valuePair.Key] = valuePair.Value; + } + } + } + + /// + /// Determine whether a given key is contained / manged by this source. + /// + /// The key that will be checked. + /// True if given key can be accessed with get / set, false otherwise. + public bool Contains(string key) + { + IDictionary values = (IDictionary) ParameterRegistry[ValueIdentifier]; + + return values.ContainsKey(key); + } + + /// + /// This is a list of keys this source provides. It is completely optional, although it is recommended to implement it. + /// + /// Once a new source is added, the keys of the sources are checked against to determine double entries which makes debugging for users easier (as log entries are produced autoamtically). + /// + public string[] Keys { get; private set; } + } +} \ No newline at end of file diff --git a/Sigma.Core/Training/Hooks/Saviors/DiskSaviorHook.cs b/Sigma.Core/Training/Hooks/Saviors/DiskSaviorHook.cs new file mode 100644 index 00000000..2fafd58c --- /dev/null +++ b/Sigma.Core/Training/Hooks/Saviors/DiskSaviorHook.cs @@ -0,0 +1,118 @@ +/* +MIT License + +Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer + +For full license see LICENSE in the root directory of this project. +*/ + +using System; +using System.Reflection; +using log4net; +using Sigma.Core.Persistence; +using Sigma.Core.Utils; + +namespace Sigma.Core.Training.Hooks.Saviors +{ + /// + /// A disk savior hook for selectively storing certain objects on disk on certain conditions / at certain intervals. + /// + /// The type of object to store. 
+ [Serializable] + public class DiskSaviorHook : BaseHook + { + [NonSerialized] + private readonly ILog _logger = LogManager.GetLogger(Assembly.GetCallingAssembly(), typeof(DiskSaviorHook).Namespace + "." + typeof(DiskSaviorHook).Name); + + + /// + /// Create a savior hook that will automatically serialise a certain registry entry. + /// + /// + /// The file namer to store to disk as. + /// Indicate whether or not to report when the specified object was serialised. + public DiskSaviorHook(string registryEntryToSave, string fileName, bool verbose = true) : this(Utils.TimeStep.Every(1, TimeScale.Iteration), registryEntryToSave, fileName, verbose) + { + } + + /// + /// Create a savior hook that will automatically serialise a certain registry entry. + /// + /// The time step. + /// + /// The file namer to store to disk as. + /// Indicate whether or not to report when the specified object was serialised. + public DiskSaviorHook(ITimeStep timestep, string registryEntryToSave, string fileName, bool verbose = true) : this(timestep, registryEntryToSave, Namers.Static(fileName), verbose) + { + } + + /// + /// Create a savior hook that will automatically serialise a certain registry entry. + /// + /// + /// The file namer to store to disk as. + /// Indicate whether or not to report when the specified object was serialised. + public DiskSaviorHook(string registryEntryToSave, INamer fileNamer, bool verbose = true) : this(Utils.TimeStep.Every(1, TimeScale.Iteration), registryEntryToSave, fileNamer, verbose) + { + } + + /// + /// Create a savior hook that will automatically serialise a certain registry entry. + /// + /// The time step. + /// + /// The file namer to store to disk as. + /// Indicate whether or not to report when the specified object was serialised. 
+ public DiskSaviorHook(ITimeStep timestep, string registryEntryToSave, INamer fileNamer, bool verbose = true) : this(timestep, registryEntryToSave, fileNamer, o => o, verbose) + { + } + + /// + /// Create a savior hook that will automatically serialise a certain registry entry. + /// + /// The time step. + /// + /// The file namer to store to disk as. + /// The select function to apply. + /// Indicate whether or not to report when the specified object was serialised. + public DiskSaviorHook(ITimeStep timestep, string registryEntryToSave, INamer fileNamer, Func selectFunction, bool verbose = true) : base(timestep, registryEntryToSave) + { + if (registryEntryToSave == null) throw new ArgumentNullException(nameof(registryEntryToSave)); + if (fileNamer == null) throw new ArgumentNullException(nameof(fileNamer)); + if (selectFunction == null) throw new ArgumentNullException(nameof(selectFunction)); + + ParameterRegistry["registry_entry_to_save"] = registryEntryToSave; + ParameterRegistry["file_namer"] = fileNamer; + ParameterRegistry["select_function"] = selectFunction; + ParameterRegistry["verbose"] = verbose; + + DefaultTargetMode = TargetMode.Global; + } + + /// + /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied. + /// + /// The registry containing the required values for this hook's execution. + /// A helper resolver for complex registry entries (automatically cached). 
+ public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) + { + string registryEntryToSave = ParameterRegistry.Get("registry_entry_to_save"); + INamer fileNamer = ParameterRegistry.Get("file_namer"); + object toSerialise = resolver.ResolveGetSingle(registryEntryToSave); + bool verbose = ParameterRegistry.Get("verbose"); + Func selectFunction = ParameterRegistry.Get>("select_function"); + + toSerialise = selectFunction.Invoke((T) toSerialise); + + lock (fileNamer) + { + Serialisation.WriteBinaryFile(toSerialise, fileNamer.GetName(registry, resolver, this), verbose: false); + } + + if (verbose) + { + _logger.Info($"Saved \"{registryEntryToSave}\" to \"{SigmaEnvironment.Globals.Get("storage_path")}{fileNamer}\"."); + } + } + } +} diff --git a/Sigma.Core/Training/ITrainer.cs b/Sigma.Core/Training/ITrainer.cs index 418ba746..1c762373 100644 --- a/Sigma.Core/Training/ITrainer.cs +++ b/Sigma.Core/Training/ITrainer.cs @@ -99,7 +99,12 @@ public interface ITrainer IRegistry Registry { get; } /// - /// Add an initialiser by registry resolve string (e.g. FC*.weights, *.weights, Layer1.biases, Layer2.*). + /// Force (re-) initialisation of this trainer's , ignoring whether it has already been initialised or not. + /// + bool ForceInitialisation { get; set; } + + /// + /// Add an initialiser by registry resolve string (e.g. fc*.weights, *.weights, layer1.biases, layer2.*). /// Registry resolve notation may be used as the initialiser will be executed on all ndarrays which resolve to a match in a certain layer and match identifier. /// /// The identifier (registry resolve string). @@ -143,7 +148,8 @@ public interface ITrainer void AddLocalHook(IHook hook); /// - /// Initialise this trainer and the network to be trained using the set initialisers. Set up all handlers and constructs used to run the trainer. 
+ /// Initialise this trainer and the network to be trained using the set initialisers (if the network is not already initialised or force initialisation is set). + /// Set up all handlers and constructs used to run the trainer. /// /// The computation handler to initialise for (must be the interchangeable with the one used for running the network). void Initialise(IComputationHandler handler); diff --git a/Sigma.Core/Training/Operators/Backends/NativeCpu/CpuMultithreadedOperator.cs b/Sigma.Core/Training/Operators/Backends/NativeCpu/CpuMultithreadedOperator.cs index 737ee21a..da5af8e2 100644 --- a/Sigma.Core/Training/Operators/Backends/NativeCpu/CpuMultithreadedOperator.cs +++ b/Sigma.Core/Training/Operators/Backends/NativeCpu/CpuMultithreadedOperator.cs @@ -81,7 +81,20 @@ public class CpuMultithreadedOperator : BaseOperator /// public ThreadPriority WorkerPriority { get; } - /// + /// + /// Create a new using the default (). + /// The will receive its default value (). + /// + public CpuMultithreadedOperator(double useAvailablePower = 0.5) : this(_InternalGetOptimalCpuWorkerCount(useAvailablePower)) + { + } + + private static int _InternalGetOptimalCpuWorkerCount(double useAvailablePower) + { + return (int) (Environment.ProcessorCount * useAvailablePower); + } + + /// /// Create a new using the default (). /// The will receive its default value (). 
/// diff --git a/Sigma.Core/Training/Operators/BaseOperator.cs b/Sigma.Core/Training/Operators/BaseOperator.cs index 5d954a9f..515f09ae 100644 --- a/Sigma.Core/Training/Operators/BaseOperator.cs +++ b/Sigma.Core/Training/Operators/BaseOperator.cs @@ -297,7 +297,7 @@ public void PushProgress(IWorker worker) localIterationNumbers[WorkerIndicesByWorkers[worker]] = worker.LocalIterationNumber; - if (localIterationNumbers.Any(i => i != worker.LocalIterationNumber)) + if (localIterationNumbers.Any(i => i < worker.LocalIterationNumber)) { allWorkersAtIteration = false; } @@ -1155,6 +1155,7 @@ protected void PopulateRegistry(IRegistry registry, INetwork localNetwork, IOpti registry["trainer"] = Trainer.Registry; registry["epoch"] = localEpochNumber; registry["iteration"] = localIterationNumber; + registry["runtime_millis"] = RunningTimeMilliseconds; if (!registry.ContainsKey("shared") || !(registry["shared"] is IRegistry)) { diff --git a/Sigma.Core/Training/Optimisers/Gradient/BaseGradientOptimiser.cs b/Sigma.Core/Training/Optimisers/Gradient/BaseGradientOptimiser.cs index 2158a004..9dc0a2cb 100644 --- a/Sigma.Core/Training/Optimisers/Gradient/BaseGradientOptimiser.cs +++ b/Sigma.Core/Training/Optimisers/Gradient/BaseGradientOptimiser.cs @@ -7,7 +7,8 @@ For full license see LICENSE in the root directory of this project. 
*/ using System; -using log4net; +using System.Collections.Generic; +using System.Linq; using Sigma.Core.Architecture; using Sigma.Core.Handlers; using Sigma.Core.Layers; @@ -33,8 +34,6 @@ public abstract class BaseGradientOptimiser : IOptimiser /// protected readonly string ExternalCostAlias; - [NonSerialized] - private readonly ILog _logger = LogManager.GetLogger(System.Reflection.MethodBase.GetCurrentMethod().DeclaringType); private bool _prepared; private uint _traceTag; @@ -48,6 +47,7 @@ protected BaseGradientOptimiser(string externalCostAlias = "external_cost") ExternalCostAlias = externalCostAlias; Registry = new Registry(tags: "optimiser"); + Registry["updates"] = new Dictionary(); } /// @@ -133,9 +133,44 @@ public void Run(INetwork network, IComputationHandler handler) layerBuffer.Parameters[trainableParameter] = handler.ClearTrace(layerBuffer.Parameters.Get(trainableParameter)); } + + // outputs might have a trace as well, clear everything + _InternalClearAllTraces(layerBuffer.Inputs, handler); + _InternalClearAllTraces(layerBuffer.Outputs, handler); + } + } + + private static void _InternalClearAllTraces(IReadOnlyDictionary layerExternalBuffer, IComputationHandler handler) + { + foreach (string output in layerExternalBuffer.Keys.ToArray()) + { + IRegistry registry = layerExternalBuffer[output]; + + foreach (string parameter in registry.Keys.ToArray()) + { + ITraceable traceable = registry[parameter] as ITraceable; + + if (traceable != null) + { + registry[parameter] = handler.ClearTrace(traceable); + } + } } } + /// + /// Expose a parameter update to the outside through the gradient optimiser utilities. + /// + /// The parameter identifier. + /// The update. 
+ protected void ExposeParameterUpdate(string parameterIdentifier, INDArray update) + { + if (parameterIdentifier == null) throw new ArgumentNullException(nameof(parameterIdentifier)); + if (update == null) throw new ArgumentNullException(nameof(update)); + + Registry.Get>("updates")[parameterIdentifier] = update; + } + /// /// Get the total cost from a certain network using a certain computation handler and put the relevant information in the cost registry (total, partial, importances). /// diff --git a/Sigma.Core/Training/Optimisers/Gradient/GradientDescentOptimiser.cs b/Sigma.Core/Training/Optimisers/Gradient/GradientDescentOptimiser.cs index b6573533..376f6a62 100644 --- a/Sigma.Core/Training/Optimisers/Gradient/GradientDescentOptimiser.cs +++ b/Sigma.Core/Training/Optimisers/Gradient/GradientDescentOptimiser.cs @@ -32,7 +32,11 @@ public GradientDescentOptimiser(double learningRate, string externalCostAlias = protected override INDArray Optimise(string paramIdentifier, INDArray parameter, INDArray gradient, IComputationHandler handler) { - return handler.Add(parameter, handler.Multiply(gradient, -Registry.Get("learning_rate"))); + INDArray update = handler.Multiply(gradient, -Registry.Get("learning_rate")); + + ExposeParameterUpdate(paramIdentifier, update); + + return handler.Add(parameter, update); } /// diff --git a/Sigma.Core/Training/Optimisers/Gradient/Memory/AdadeltaOptimiser.cs b/Sigma.Core/Training/Optimisers/Gradient/Memory/AdadeltaOptimiser.cs index 9783a10e..0c82fc98 100644 --- a/Sigma.Core/Training/Optimisers/Gradient/Memory/AdadeltaOptimiser.cs +++ b/Sigma.Core/Training/Optimisers/Gradient/Memory/AdadeltaOptimiser.cs @@ -60,6 +60,8 @@ protected override INDArray Optimise(string paramIdentifier, INDArray parameter, SetMemory(memoryIdentifierGradient, currentAccumulatedGradient); SetMemory(memoryIdentifierUpdate, currentAccumulatedUpdate); + ExposeParameterUpdate(paramIdentifier, update); + // compute optimised parameter using computed update return 
handler.Add(parameter, update); } diff --git a/Sigma.Core/Training/Optimisers/Gradient/Memory/AdagradOptimiser.cs b/Sigma.Core/Training/Optimisers/Gradient/Memory/AdagradOptimiser.cs index 28c8bd4b..fc304637 100644 --- a/Sigma.Core/Training/Optimisers/Gradient/Memory/AdagradOptimiser.cs +++ b/Sigma.Core/Training/Optimisers/Gradient/Memory/AdagradOptimiser.cs @@ -48,7 +48,11 @@ protected override INDArray Optimise(string paramIdentifier, INDArray parameter, INDArray adaptedLearningRate = handler.Divide(learningRate, handler.SquareRoot(handler.Add(squaredGradientSum, smoothing))); - return handler.Add(parameter, handler.Multiply(gradient, handler.Multiply(adaptedLearningRate, -1.0))); + INDArray update = handler.Multiply(gradient, handler.Multiply(adaptedLearningRate, -1.0)); + + ExposeParameterUpdate(paramIdentifier, update); + + return handler.Add(parameter, update); } /// diff --git a/Sigma.Core/Training/Optimisers/Gradient/Memory/MomentumGradientOptimiser.cs b/Sigma.Core/Training/Optimisers/Gradient/Memory/MomentumGradientOptimiser.cs index 4bdf364e..07e33390 100644 --- a/Sigma.Core/Training/Optimisers/Gradient/Memory/MomentumGradientOptimiser.cs +++ b/Sigma.Core/Training/Optimisers/Gradient/Memory/MomentumGradientOptimiser.cs @@ -44,7 +44,11 @@ protected override INDArray Optimise(string paramIdentifier, INDArray parameter, SetMemory(paramIdentifier, velocity); - return handler.Subtract(velocity, parameter); + INDArray update = handler.Multiply(velocity, -1.0); + + ExposeParameterUpdate(paramIdentifier, update); + + return handler.Add(parameter, update); } /// diff --git a/Sigma.Core/Training/Trainer.cs b/Sigma.Core/Training/Trainer.cs index f7b6db59..353184ae 100644 --- a/Sigma.Core/Training/Trainer.cs +++ b/Sigma.Core/Training/Trainer.cs @@ -27,386 +27,448 @@ For full license see LICENSE in the root directory of this project. namespace Sigma.Core.Training { - /// - /// The default implementation. 
- /// A trainer that defines how a network should be trained, denoting initialisers, optimiser, operator, custom hooks and data to apply and use. - /// - [Serializable] - public class Trainer : ITrainer - { - [NonSerialized] - private readonly ILog _logger = LogManager.GetLogger(MethodBase.GetCurrentMethod().DeclaringType); - private readonly IList _localHooks; - private readonly IList _globalHooks; - private readonly Dictionary _additionalNameDataIterators; - private readonly IList _allHooks; - private readonly IDictionary _initialisers; - private readonly IDictionary> _valueModifiers; - private bool _initialised; - - public string Name { get; } - public SigmaEnvironment Sigma { get; set; } - public INetwork Network { get; set; } - public IOptimiser Optimiser { get; set; } - public IOperator Operator { get; set; } = new CpuSinglethreadedOperator(); - public IDataProvider DataProvider { get; set; } = new DefaultDataProvider(); - - public IReadOnlyDictionary Initialisers { get; } - public IDataIterator TrainingDataIterator { get; set; } - public IReadOnlyDictionary AdditionalNameDataIterators { get; } - public IReadOnlyCollection Hooks { get; } - public IReadOnlyCollection GlobalHooks { get; } - public IReadOnlyCollection LocalHooks { get; } - public IReadOnlyDictionary> ValueModifiers { get; } - public IRegistry Registry { get; } - - public Trainer(string name) - { - if (name == null) throw new ArgumentNullException(nameof(name)); - - Name = name; - - _allHooks = new List(); - _localHooks = new List(); - _globalHooks = new List(); - _additionalNameDataIterators = new Dictionary(); - _initialisers = new Dictionary(); - _valueModifiers = new Dictionary>(); - - Hooks = new ReadOnlyCollection(_allHooks); - GlobalHooks = new ReadOnlyCollection(_globalHooks); - LocalHooks = new ReadOnlyCollection(_localHooks); - AdditionalNameDataIterators = new ReadOnlyDictionary(_additionalNameDataIterators); - ValueModifiers = new ReadOnlyDictionary>(_valueModifiers); - Initialisers = 
new ReadOnlyDictionary(_initialisers); - Registry = new Registry(tags: "trainer"); - Registry["self"] = this; - } - - public void AddNamedDataIterator(string name, IDataIterator iterator) - { - if (_additionalNameDataIterators.ContainsKey(name)) - { - throw new ArgumentException($"Named data iterator with name {name} already registered in this trainer ({Name})."); - } - - _additionalNameDataIterators.Add(name, iterator); - } - - public void AddInitialiser(string identifier, IInitialiser initialiser) - { - if (identifier == null) { throw new ArgumentNullException(nameof(identifier)); } - if (initialiser == null) { throw new ArgumentNullException(nameof(initialiser)); } - - if (_initialisers.ContainsKey(identifier)) - { - throw new InvalidOperationException($"Cannot add duplicate identifier {identifier} for initialiser {initialiser}," + - $" identifier is already bound to initialiser {_initialisers[identifier]}"); - } - - _initialisers.Add(identifier, initialiser); - } - - public void AddValueModifier(string identifier, IValueModifier modifier) - { - _valueModifiers.TryGetValue(identifier, () => new HashSet()).Add(modifier); - } - - public void AddHook(IHook hook) - { - if (hook.DefaultTargetMode == TargetMode.Local) - { - AddLocalHook(hook); - } - else if (hook.DefaultTargetMode == TargetMode.Global) - { - AddGlobalHook(hook); - } - else - { - throw new InvalidOperationException($"Ambiguous add hook call for hook {hook} with target mode {hook.DefaultTargetMode}. " + - $"Target mode must be explicitly {nameof(TargetMode.Local)} or {nameof(TargetMode.Global)} for implicit hook add to work" + - $" (i.e. 
unable to determine where to add this hook, specify it explicitly in the caller)."); - } - } - - public void AddLocalHook(IHook hook) - { - if (Hooks.Contains(hook)) - { - throw new ArgumentException($"Duplicate hook {hook}, hook already registered in this trainer ({Name})."); - } - - _allHooks.Add(hook); - _localHooks.Add(hook); - } - - public void AddGlobalHook(IHook hook) - { - if (Hooks.Contains(hook)) - { - throw new ArgumentException($"Duplicate hook {hook}, hook already registered in this trainer ({Name})."); - } - - _allHooks.Add(hook); - _globalHooks.Add(hook); - } - - public void Initialise(IComputationHandler handler) - { - ValidateAssignedComponents(); - - _logger.Info($"Initialising trainer \"{Name}\" with handler {handler}..."); - - ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare, $"Preparing trainer {Name}"); - - Network.Initialise(handler); - - RegistryResolver networkResolver = new RegistryResolver(Network.Registry.Get("layers")); - int initialisedNDArrayCount = 0, initialisedNumberCount = 0; - List orderedInitialiserIdentifiers = _initialisers.Keys.ToList(); - orderedInitialiserIdentifiers.Sort(RegistryUtils.CompareIdentifierSpecificityAscending); - - foreach (string identifier in orderedInitialiserIdentifiers) - { - object[] values = networkResolver.ResolveGet(identifier, new object[0]); - IInitialiser initialiser = _initialisers[identifier]; - - foreach (object value in values) - { - INDArray array = value as INDArray; - - if (array != null) - { - initialiser.Initialise(array, handler, Sigma.Random); - initialisedNDArrayCount++; - } - else - { - INumber number = value as INumber; - - if (number != null) - { - initialiser.Initialise(number, handler, Sigma.Random); - initialisedNumberCount++; - } - } - } - } - - Operator.Sigma = Sigma; - Operator.Handler = Operator.Handler ?? 
handler; - Operator.Network = Network; - Operator.Trainer = this; - - // attach all given hooks - foreach (IHook hook in _globalHooks) - { - if (!Operator.AttachGlobalHook(hook)) - { - _logger.Debug($"Skipped attaching global hook {hook} in trainer \"{Name}\", operator refused to attach it."); - } - } - - foreach (IHook hook in _localHooks) - { - if (!Operator.AttachLocalHook(hook)) - { - _logger.Debug($"Skipped attaching local hook {hook} in trainer \"{Name}\", operator refused to attach it."); - } - } - - UpdateRegistry(); - - _initialised = true; - - SigmaEnvironment.TaskManager.EndTask(prepareTask); - - _logger.Info($"Done initialising trainer \"{Name}\" for handler {handler}, initialised {initialisedNDArrayCount} ndarrays and {initialisedNumberCount} numbers."); - } - - protected virtual void UpdateRegistry() - { - Registry["initialised"] = _initialised; - Registry["name"] = Name; - Registry["network"] = Network?.Registry; - Registry["optimiser"] = Optimiser?.Registry; - - Registry initialiserRegistry = new Registry(Registry, tags: "initialisers"); - Registry["initialisers"] = initialiserRegistry; - - foreach (string initialiserMatchIdentifier in Initialisers.Keys) - { - initialiserRegistry[initialiserMatchIdentifier] = Initialisers[initialiserMatchIdentifier]; - } - } - - private void ValidateAssignedComponents() - { - if (Network == null) - { - throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning a network."); - } - - if (Sigma == null) - { - throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning a sigma environment."); - } - - if (Operator == null) - { - throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning an operator."); - } - - if (DataProvider == null) - { - throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning a data provider."); - } - - Network.Validate(); - } - - private void CheckInitialised() - { - if 
(!_initialised) - { - throw new InvalidOperationException($"Trainer {Name} has not been initialised yet. Call {nameof(Initialise)}!"); - } - } - - public void Start() - { - _logger.Info($"Validating trainer state of trainer {Name} before start..."); - - ValidateAssignedComponents(); - CheckInitialised(); - - _logger.Info($"Starting operator {Operator} with trainer {Name}..."); - - Operator.Network = Network; - Operator.Trainer = this; - - Operator.Start(); - } - - public void RunTrainingIteration(INetwork localNetwork, IOptimiser localOptimiser, IRegistry localRegistry, IComputationHandler handler) - { - if (localNetwork == null) throw new ArgumentNullException(nameof(localNetwork)); - if (localOptimiser == null) throw new ArgumentNullException(nameof(localOptimiser)); - if (localRegistry == null) throw new ArgumentNullException(nameof(localRegistry)); - if (handler == null) throw new ArgumentNullException(nameof(handler)); - - CheckInitialised(); - - localOptimiser.PrepareRun(localNetwork, handler); - localNetwork.Run(handler, trainingPass: true); - localOptimiser.Run(localNetwork, handler); - ApplyValueModifiers(localRegistry, handler); - } - - private void ApplyValueModifiers(IRegistry localRegistry, IComputationHandler handler) - { - if (_valueModifiers.Count == 0) - { - return; - } - - RegistryResolver resolver = new RegistryResolver(localRegistry); - - foreach (string identifier in _valueModifiers.Keys) - { - string[] fullyResolvedIdentifiers; - object[] values = resolver.ResolveGet(identifier, out fullyResolvedIdentifiers); - - for (int i = 0; i < values.Length; i++) - { - object value = values[i]; - INDArray asNDArray = value as INDArray; - INumber asNumber = value as INumber; - - if (asNDArray != null) - { - foreach (IValueModifier modifier in _valueModifiers[identifier]) - { - asNDArray = modifier.Modify(fullyResolvedIdentifiers[i], asNDArray, handler); - } - values[i] = asNDArray; - } - else if (asNumber != null) - { - foreach (IValueModifier modifier in 
_valueModifiers[identifier]) - { - asNumber = modifier.Modify(fullyResolvedIdentifiers[i], asNumber, handler); - } - values[i] = asNumber; - } - else - { - double? asDouble = value as double?; - - if (asDouble != null) - { - foreach (IValueModifier modifier in _valueModifiers[identifier]) - { - asDouble = modifier.Modify(fullyResolvedIdentifiers[i], asDouble.Value, handler); - } - values[i] = asDouble.Value; - } - } - - resolver.ResolveSet(fullyResolvedIdentifiers[i], values[i]); - } - } - } - - /// - /// Provide the external data to a network given the current record block (typically as given by the training data iterator). - /// - /// The network to provide the data with. - /// The current record block. - public void ProvideExternalInputData(INetwork localNetwork, IDictionary currentBlock) - { - CheckInitialised(); - - DataProviderUtils.ProvideExternalInputData(DataProvider, localNetwork, currentBlock); - } - - /// - /// Provide the external output data from network to the data provider. - /// - /// The network to get the data from. - /// The current record block. - public void ProvideExternalOutputData(INetwork localNetwork, IDictionary currentBlock) - { - DataProviderUtils.ProvideExternalOutputData(DataProvider, localNetwork, currentBlock); - } - - /// - /// Reset this trainer to an un-initialised state, discard all progress information. If necessary, stop the operator (and wait for that). 
- /// - public void Reset() - { - _logger.Info($"Resetting trainer \"{Name}\" to un-initialised state, discarding all progress data..."); - - if (Operator?.State != ExecutionState.None) - { - _logger.Info($"Signalling operator to stop and reset, waiting for state change signal to continue trainer reset..."); - - Operator.SignalReset(); - Operator.WaitForStateChanged(); - } - - Network?.Reset(); - _initialised = false; - - UpdateRegistry(); - - _logger.Info($"Done resetting trainer \"{Name}\" to un-initialised state, discarded all progress data and stopped operator."); - } - - public override string ToString() - { - return $"trainer \"{Name}\""; - } - } + /// + /// The default implementation. + /// A trainer that defines how a network should be trained, denoting initialisers, optimiser, operator, custom hooks and data to apply and use. + /// + [Serializable] + public class Trainer : ITrainer + { + [NonSerialized] + private readonly ILog _logger = LogManager.GetLogger(MethodBase.GetCurrentMethod().DeclaringType); + private readonly IList _localHooks; + private readonly IList _globalHooks; + private readonly Dictionary _additionalNameDataIterators; + private readonly IList _allHooks; + private readonly IDictionary _initialisers; + private readonly IDictionary> _valueModifiers; + private bool _initialised; + + /// + public string Name { get; } + + /// + public SigmaEnvironment Sigma { get; set; } + + /// + public bool ForceInitialisation { get; set; } + + /// + public INetwork Network { get; set; } + + /// + public IOptimiser Optimiser { get; set; } + + /// + public IOperator Operator { get; set; } = new CpuSinglethreadedOperator(); + + /// + public IDataProvider DataProvider { get; set; } = new DefaultDataProvider(); + + /// + public IReadOnlyDictionary Initialisers { get; } + + /// + public IDataIterator TrainingDataIterator { get; set; } + + /// + public IReadOnlyDictionary AdditionalNameDataIterators { get; } + + /// + public IReadOnlyCollection Hooks { get; } + + 
/// + public IReadOnlyCollection GlobalHooks { get; } + + /// + public IReadOnlyCollection LocalHooks { get; } + + /// + public IReadOnlyDictionary> ValueModifiers { get; } + + /// + public IRegistry Registry { get; } + + /// + /// Create a trainer with a certain name. + /// + /// The name. + public Trainer(string name) + { + if (name == null) throw new ArgumentNullException(nameof(name)); + + Name = name; + + _allHooks = new List(); + _localHooks = new List(); + _globalHooks = new List(); + _additionalNameDataIterators = new Dictionary(); + _initialisers = new Dictionary(); + _valueModifiers = new Dictionary>(); + + Hooks = new ReadOnlyCollection(_allHooks); + GlobalHooks = new ReadOnlyCollection(_globalHooks); + LocalHooks = new ReadOnlyCollection(_localHooks); + AdditionalNameDataIterators = new ReadOnlyDictionary(_additionalNameDataIterators); + ValueModifiers = new ReadOnlyDictionary>(_valueModifiers); + Initialisers = new ReadOnlyDictionary(_initialisers); + Registry = new Registry(tags: "trainer"); + Registry["self"] = this; + } + + /// + public void AddNamedDataIterator(string name, IDataIterator iterator) + { + if (_additionalNameDataIterators.ContainsKey(name)) + { + throw new ArgumentException($"Named data iterator with name {name} already registered in this trainer ({Name})."); + } + + _additionalNameDataIterators.Add(name, iterator); + } + + /// + public void AddInitialiser(string identifier, IInitialiser initialiser) + { + if (identifier == null) { throw new ArgumentNullException(nameof(identifier)); } + if (initialiser == null) { throw new ArgumentNullException(nameof(initialiser)); } + + if (_initialisers.ContainsKey(identifier)) + { + throw new InvalidOperationException($"Cannot add duplicate identifier {identifier} for initialiser {initialiser}," + + $" identifier is already bound to initialiser {_initialisers[identifier]}"); + } + + _initialisers.Add(identifier, initialiser); + } + + /// + public void AddValueModifier(string identifier, 
IValueModifier modifier) + { + _valueModifiers.TryGetValue(identifier, () => new HashSet()).Add(modifier); + } + + /// + public void AddHook(IHook hook) + { + if (hook.DefaultTargetMode == TargetMode.Local) + { + AddLocalHook(hook); + } + else if (hook.DefaultTargetMode == TargetMode.Global) + { + AddGlobalHook(hook); + } + else + { + throw new InvalidOperationException($"Ambiguous add hook call for hook {hook} with target mode {hook.DefaultTargetMode}. " + + $"Target mode must be explicitly {nameof(TargetMode.Local)} or {nameof(TargetMode.Global)} for implicit hook add to work" + + $" (i.e. unable to determine where to add this hook, specify it explicitly in the caller)."); + } + } + + /// + public void AddLocalHook(IHook hook) + { + if (Hooks.Contains(hook)) + { + throw new ArgumentException($"Duplicate hook {hook}, hook already registered in this trainer ({Name})."); + } + + _allHooks.Add(hook); + _localHooks.Add(hook); + } + + /// + public void AddGlobalHook(IHook hook) + { + if (Hooks.Contains(hook)) + { + throw new ArgumentException($"Duplicate hook {hook}, hook already registered in this trainer ({Name})."); + } + + _allHooks.Add(hook); + _globalHooks.Add(hook); + } + + /// + public void Initialise(IComputationHandler handler) + { + ValidateAssignedComponents(); + + _logger.Info($"Initialising trainer \"{Name}\" with handler {handler}..."); + + ITaskObserver prepareTask = SigmaEnvironment.TaskManager.BeginTask(TaskType.Prepare, $"Preparing trainer {Name}"); + + int initialisedNDArrayCount; + int initialisedNumberCount; + + Network.AssociatedHandler = Operator.Handler; + + if (Network.Initialised && !ForceInitialisation) + { + initialisedNumberCount = 0; + initialisedNDArrayCount = 0; + + _logger.Info($"Skipping network initialisation because network was already initialised and force initialisation flag is set to false..."); + } + else + { + InitialiseNetwork(handler, out initialisedNumberCount, out initialisedNDArrayCount); + } + + Operator.Sigma = Sigma; + 
Operator.Handler = Operator.Handler ?? handler; + Operator.Network = Network; + Operator.Trainer = this; + + // attach all given hooks + foreach (IHook hook in _globalHooks) + { + if (!Operator.AttachGlobalHook(hook)) + { + _logger.Debug($"Skipped attaching global hook {hook} in trainer \"{Name}\", operator refused to attach it."); + } + } + + foreach (IHook hook in _localHooks) + { + if (!Operator.AttachLocalHook(hook)) + { + _logger.Debug($"Skipped attaching local hook {hook} in trainer \"{Name}\", operator refused to attach it."); + } + } + + UpdateRegistry(); + + _initialised = true; + + SigmaEnvironment.TaskManager.EndTask(prepareTask); + + _logger.Info($"Done initialising trainer \"{Name}\" for handler {handler}, initialised {initialisedNDArrayCount} ndarrays and {initialisedNumberCount} numbers."); + } + + private void InitialiseNetwork(IComputationHandler handler, out int initialisedNumberCount, out int initialisedNDArrayCount) + { + Network.Initialise(handler); + + initialisedNDArrayCount = 0; + initialisedNumberCount = 0; + + RegistryResolver networkResolver = new RegistryResolver(Network.Registry.Get("layers")); + List orderedInitialiserIdentifiers = _initialisers.Keys.ToList(); + orderedInitialiserIdentifiers.Sort(RegistryUtils.CompareIdentifierSpecificityAscending); + + foreach (string identifier in orderedInitialiserIdentifiers) + { + object[] values = networkResolver.ResolveGet(identifier, new object[0]); + IInitialiser initialiser = _initialisers[identifier]; + + foreach (object value in values) + { + INDArray array = value as INDArray; + + if (array != null) + { + initialiser.Initialise(array, handler, Sigma.Random); + initialisedNDArrayCount++; + } + else + { + INumber number = value as INumber; + + if (number != null) + { + initialiser.Initialise(number, handler, Sigma.Random); + initialisedNumberCount++; + } + } + } + } + } + + protected virtual void UpdateRegistry() + { + Registry["initialised"] = _initialised; + Registry["name"] = Name; + 
Registry["network"] = Network?.Registry; + Registry["optimiser"] = Optimiser?.Registry; + + Registry initialiserRegistry = new Registry(Registry, tags: "initialisers"); + Registry["initialisers"] = initialiserRegistry; + + foreach (string initialiserMatchIdentifier in Initialisers.Keys) + { + initialiserRegistry[initialiserMatchIdentifier] = Initialisers[initialiserMatchIdentifier]; + } + } + + private void ValidateAssignedComponents() + { + if (Network == null) + { + throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning a network."); + } + + if (Sigma == null) + { + throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning a sigma environment."); + } + + if (Operator == null) + { + throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning an operator."); + } + + if (DataProvider == null) + { + throw new InvalidOperationException($"Cannot initialise trainer {Name} before assigning a data provider."); + } + + Network.Validate(); + } + + private void CheckInitialised() + { + if (!_initialised) + { + throw new InvalidOperationException($"Trainer {Name} has not been initialised yet. 
Call {nameof(Initialise)}!"); + } + } + + public void Start() + { + _logger.Info($"Validating trainer state of trainer {Name} before start..."); + + ValidateAssignedComponents(); + CheckInitialised(); + + _logger.Info($"Starting operator {Operator} with trainer {Name}..."); + + Operator.Network = Network; + Operator.Trainer = this; + + Operator.Start(); + } + + public void RunTrainingIteration(INetwork localNetwork, IOptimiser localOptimiser, IRegistry localRegistry, IComputationHandler handler) + { + if (localNetwork == null) throw new ArgumentNullException(nameof(localNetwork)); + if (localOptimiser == null) throw new ArgumentNullException(nameof(localOptimiser)); + if (localRegistry == null) throw new ArgumentNullException(nameof(localRegistry)); + if (handler == null) throw new ArgumentNullException(nameof(handler)); + + CheckInitialised(); + + localOptimiser.PrepareRun(localNetwork, handler); + localNetwork.Run(handler, trainingPass: true); + localOptimiser.Run(localNetwork, handler); + ApplyValueModifiers(localRegistry, handler); + } + + private void ApplyValueModifiers(IRegistry localRegistry, IComputationHandler handler) + { + if (_valueModifiers.Count == 0) + { + return; + } + + RegistryResolver resolver = new RegistryResolver(localRegistry); + + foreach (string identifier in _valueModifiers.Keys) + { + string[] fullyResolvedIdentifiers; + object[] values = resolver.ResolveGet(identifier, out fullyResolvedIdentifiers); + + for (int i = 0; i < values.Length; i++) + { + object value = values[i]; + INDArray asNDArray = value as INDArray; + INumber asNumber = value as INumber; + + if (asNDArray != null) + { + foreach (IValueModifier modifier in _valueModifiers[identifier]) + { + asNDArray = modifier.Modify(fullyResolvedIdentifiers[i], asNDArray, handler); + } + values[i] = asNDArray; + } + else if (asNumber != null) + { + foreach (IValueModifier modifier in _valueModifiers[identifier]) + { + asNumber = modifier.Modify(fullyResolvedIdentifiers[i], asNumber, 
handler); + } + values[i] = asNumber; + } + else + { + double? asDouble = value as double?; + + if (asDouble != null) + { + foreach (IValueModifier modifier in _valueModifiers[identifier]) + { + asDouble = modifier.Modify(fullyResolvedIdentifiers[i], asDouble.Value, handler); + } + values[i] = asDouble.Value; + } + } + + resolver.ResolveSet(fullyResolvedIdentifiers[i], values[i]); + } + } + } + + /// + /// Provide the external data to a network given the current record block (typically as given by the training data iterator). + /// + /// The network to provide the data with. + /// The current record block. + public void ProvideExternalInputData(INetwork localNetwork, IDictionary currentBlock) + { + CheckInitialised(); + + DataProviderUtils.ProvideExternalInputData(DataProvider, localNetwork, currentBlock); + } + + /// + /// Provide the external output data from network to the data provider. + /// + /// The network to get the data from. + /// The current record block. + public void ProvideExternalOutputData(INetwork localNetwork, IDictionary currentBlock) + { + DataProviderUtils.ProvideExternalOutputData(DataProvider, localNetwork, currentBlock); + } + + /// + /// Reset this trainer to an un-initialised state, discard all progress information. If necessary, stop the operator (and wait for that). 
+ /// + public void Reset() + { + _logger.Info($"Resetting trainer \"{Name}\" to un-initialised state, discarding all progress data..."); + + if (Operator?.State != ExecutionState.None) + { + _logger.Info($"Signalling operator to stop and reset, waiting for state change signal to continue trainer reset..."); + + Operator.SignalReset(); + Operator.WaitForStateChanged(); + } + + Network?.Reset(); + _initialised = false; + + UpdateRegistry(); + + _logger.Info($"Done resetting trainer \"{Name}\" to un-initialised state, discarded all progress data and stopped operator."); + } + + public override string ToString() + { + return $"trainer \"{Name}\""; + } + } } \ No newline at end of file diff --git a/Sigma.Core/Utils/ArrayUtils.cs b/Sigma.Core/Utils/ArrayUtils.cs index ee50396e..622679bf 100644 --- a/Sigma.Core/Utils/ArrayUtils.cs +++ b/Sigma.Core/Utils/ArrayUtils.cs @@ -18,6 +18,25 @@ namespace Sigma.Core.Utils /// public static class ArrayUtils { + /// + /// Concatenate two given arrays into one result array (b is appended after a). + /// + /// The array element type. + /// The first array. + /// The second array. + /// A concatenated array of a and b. + public static T[] Concatenate(T[] a, T[] b) + { + if (a == null) throw new ArgumentNullException(nameof(a)); + if (b == null) throw new ArgumentNullException(nameof(b)); + + T[] result = new T[a.Length + b.Length]; + a.CopyTo(result, 0); + b.CopyTo(result, a.Length); + + return result; + } + /// /// The product of an integer array (i.e. all values multiplied with each other). /// diff --git a/Sigma.Core/Utils/LoggingUtils.cs b/Sigma.Core/Utils/LoggingUtils.cs index 5156b49d..2039964e 100644 --- a/Sigma.Core/Utils/LoggingUtils.cs +++ b/Sigma.Core/Utils/LoggingUtils.cs @@ -6,6 +6,7 @@ MIT License For full license see LICENSE in the root directory of this project. */ +using System; using log4net; using log4net.Core; @@ -24,26 +25,30 @@ public static class LoggingUtils /// The logger to use. 
public static void Log(Level level, string message, ILog logger) { - if (level == Level.Fatal) + if (level.Value == Level.Fatal.Value) { logger.Fatal(message); } - else if (level == Level.Error) + else if (level.Value == Level.Error.Value) { logger.Error(message); } - else if (level == Level.Warn) + else if (level.Value == Level.Warn.Value) { logger.Warn(message); } - else if (level == Level.Info) + else if (level.Value == Level.Info.Value) { logger.Info(message); } - else if (level == Level.Debug) + else if (level.Value == Level.Debug.Value) { logger.Debug(message); } - } + else + { + throw new ArgumentException($"Level {level} is not a supported logging level (supported levels are fatal, error, warn, info, debug)."); + } + } } } diff --git a/Sigma.Core/Utils/Namers.cs b/Sigma.Core/Utils/Namers.cs new file mode 100644 index 00000000..f69e6ff7 --- /dev/null +++ b/Sigma.Core/Utils/Namers.cs @@ -0,0 +1,160 @@ +/* +MIT License + +Copyright (c) 2016-2017 Florian Cäsar, Michael Plainer + +For full license see LICENSE in the root directory of this project. +*/ + +using System; + +namespace Sigma.Core.Utils +{ + /// + /// A common namer interface for static and dynamic naming of things. Any things. + /// + public interface INamer + { + /// + /// Get the name using a certain registry (and the corresponding resolver) for a certain sender. + /// Note: The sender may be used to get extra information. + /// + /// The parameter registry. + /// The resolver to the parameter registry. + /// The sender of this naming request. + /// The name using the given information. + string GetName(IRegistry registry, IRegistryResolver resolver, object sender); + } + + /// + /// A static namer using a ... static name. + /// + [Serializable] + public class StaticNamer : INamer + { + private readonly string _name; + + /// + /// Create a static namer for a certain name. + /// + /// The name.
+ public StaticNamer(string name) + { + if (name == null) throw new ArgumentNullException(nameof(name)); + + _name = name; + } + + /// + public string GetName(IRegistry registry, IRegistryResolver resolver, object sender) + { + return _name; + } + } + + /// + /// Create a dynamic namer using a certain lambda function for . + /// + [Serializable] + public class DynamicLambdaNamer : INamer + { + private readonly Func _nameFunction; + + /// + /// Create a dynamic namer using the lambda function + /// + /// The name function + public DynamicLambdaNamer(Func nameFunction) + { + if (nameFunction == null) throw new ArgumentNullException(nameof(nameFunction)); + + _nameFunction = nameFunction; + } + + /// + public string GetName(IRegistry registry, IRegistryResolver resolver, object sender) + { + return _nameFunction.Invoke(registry, resolver, sender); + } + } + + /// + /// A dynamic namer using individual parameters as items in a format string. + /// + [Serializable] + public class DynamicItemisedNamer : INamer + { + private readonly string _formatString; + private readonly string[] _parameterIdentifiers; + private readonly object[] _bufferParameters; + + /// + /// Create a dynamic itemised namer using a format string and parameter identifiers (which will be resolved to the given values). + /// Note: Parameter order is preserved. + /// + /// The format string. + /// The parameter identifiers. + public DynamicItemisedNamer(string formatString, params string[] parameterIdentifiers) + { + if (formatString == null) throw new ArgumentNullException(nameof(formatString)); + if (parameterIdentifiers == null) throw new ArgumentNullException(nameof(parameterIdentifiers)); + + _formatString = formatString; + _parameterIdentifiers = parameterIdentifiers; // not sure, maybe copy?
+ _bufferParameters = new object[parameterIdentifiers.Length]; + } + + /// + public string GetName(IRegistry registry, IRegistryResolver resolver, object sender) + { + for (int i = 0; i < _parameterIdentifiers.Length; i++) + { + _bufferParameters[i] = resolver.ResolveGetSingle(_parameterIdentifiers[i]); + } + + string name = string.Format(_formatString, _bufferParameters); + + for (var i = 0; i < _bufferParameters.Length; i++) + { + _bufferParameters[i] = null; + } + + return name; + } + } + + /// + /// A utility collection for various static and dynamic namers. + /// + public static class Namers + { + /// + /// A static namer using a ... static name. + /// + public static INamer Static(string name) + { + return new StaticNamer(name); + } + + /// + /// Create a dynamic namer using the lambda function + /// + /// The name function + public static INamer Dynamic(Func nameFunction) + { + return new DynamicLambdaNamer(nameFunction); + } + + /// + /// Create a dynamic itemised namer using a format string and parameter identifiers (which will be resolved to the given values). + /// Note: Parameter order is preserved. + /// + /// The format string. + /// The parameter identifiers. + // TODO fix this attr, for some reason can't be found [StringFormatMethod("formatString")] + public static INamer Dynamic(string formatString, params string[] parameterIdentifiers) + { + return new DynamicItemisedNamer(formatString, parameterIdentifiers); + } + } +} diff --git a/Sigma.Core/Utils/Registry.cs b/Sigma.Core/Utils/Registry.cs index 6ffa285f..94649cc6 100644 --- a/Sigma.Core/Utils/Registry.cs +++ b/Sigma.Core/Utils/Registry.cs @@ -13,6 +13,7 @@ For full license see LICENSE in the root directory of this project. 
using System.Linq; using System.Text; using System.Text.RegularExpressions; +using Sigma.Core.Persistence; namespace Sigma.Core.Utils { @@ -22,11 +23,14 @@ namespace Sigma.Core.Utils /// Registries can be chained and represent a hierarchy, which can then be referred to using dot notation. /// [Serializable] - public class Registry : IRegistry + public class Registry : IRegistry, ISerialisationNotifier { internal Dictionary MappedValues; internal Dictionary AssociatedTypes; + [NonSerialized] + private ISet _hierarchyChangeListeners; + public bool CheckTypes { get; set; @@ -51,7 +55,10 @@ public bool ExceptionOnCopyNonDeepCopyable public ISet Tags { get; } - public ISet HierarchyChangeListeners { get; } + public ISet HierarchyChangeListeners + { + get { return _hierarchyChangeListeners; } + } /// /// Create a registry with a certain (optional) parent and an (optional) list of tags. @@ -72,7 +79,7 @@ public Registry(IRegistry parent = null, params string[] tags) } Tags = new HashSet(tags); - HierarchyChangeListeners = new HashSet(); + _hierarchyChangeListeners = new HashSet(); } public object DeepCopy() @@ -373,6 +380,11 @@ public IEnumerator GetValueIterator() } public override string ToString() + { + return $"registry tagged as {(Tags.Count == 0 ? "" : string.Join("", Tags))} with {MappedValues.Count} entries"; + } + + public string FancyToString() { StringBuilder str = new StringBuilder(); @@ -403,6 +415,28 @@ public bool RegistryContentEquals(IRegistry other) { return other != null && MappedValues.Count == other.Count && MappedValues.Keys.All(k => other.ContainsKey(k) && Equals(MappedValues[k], other[k])); } + + /// + /// Called before this object is serialised. + /// + public void OnSerialising() + { + } + + /// + /// Called after this object was serialised. + /// + public void OnSerialised() + { + } + + /// + /// Called after this object was de-serialised. 
+ /// + public void OnDeserialised() + { + _hierarchyChangeListeners = new HashSet(); + } } /// diff --git a/Sigma.Core/Utils/RegistryResolver.cs b/Sigma.Core/Utils/RegistryResolver.cs index 70158acf..f5b2f750 100644 --- a/Sigma.Core/Utils/RegistryResolver.cs +++ b/Sigma.Core/Utils/RegistryResolver.cs @@ -344,7 +344,7 @@ private void AddMatchingIdentifiersFromRegistryTree(int hierarchyLevel, int last bool noneMatched = true; - foreach (string identifier in currentRootAtLevel.Keys) + foreach (string identifier in currentRootAtLevel.Keys.ToArray()) // TODO ugly hack, toarray is inefficient and is just to prevent "random" concurrent modification exception { if (regex.IsMatch(identifier)) { diff --git a/Sigma.Tests.Internals.Backend/Program.cs b/Sigma.Tests.Internals.Backend/Program.cs index ac83ab90..74d49456 100644 --- a/Sigma.Tests.Internals.Backend/Program.cs +++ b/Sigma.Tests.Internals.Backend/Program.cs @@ -17,6 +17,7 @@ using Sigma.Core.MathAbstract; using Sigma.Core.MathAbstract.Backends.SigmaDiff; using Sigma.Core.Monitors.Synchronisation; +using Sigma.Core.Persistence; using Sigma.Core.Training; using Sigma.Core.Training.Hooks; using Sigma.Core.Training.Hooks.Reporters; @@ -31,371 +32,392 @@ using System.Diagnostics; using System.Linq; using System.Threading; +using Sigma.Core.Training.Hooks.Processors; +using Sigma.Core.Training.Hooks.Saviors; +using Sigma.Core.Training.Optimisers.Gradient; namespace Sigma.Tests.Internals.Backend { - public static class Program - { - public static MinibatchIterator TrainingIterator; - - private static void Main(string[] args) - { - SigmaEnvironment.EnableLogging(xml: true); - SigmaEnvironment.Globals["web_proxy"] = WebUtils.GetProxyFromFileOrDefault(".customproxy"); - - SampleTrainerOperatorWorkerMnist(); - - Console.WriteLine("Program ended, waiting for termination, press any key..."); - Console.ReadKey(); - } - - private static void SampleTrainerOperatorWorkerIris() - { - SigmaEnvironment sigma = 
SigmaEnvironment.Create("trainer_test"); - - sigma.Prepare(); - - var irisReader = new CsvRecordReader(new MultiSource(new FileSource("iris.data"), new UrlSource("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"))); - IRecordExtractor irisExtractor = irisReader.Extractor("inputs", new[] { 0, 3 }, "targets", 4).AddValueMapping(4, "Iris-setosa", "Iris-versicolor", "Iris-virginica") - .Preprocess(new OneHotPreprocessor(sectionName: "targets", minValue: 0, maxValue: 2)) - .Preprocess(new AdaptiveNormalisingPreprocessor(minOutputValue: 0.0, maxOutputValue: 1.0)); - - IDataset dataset = new Dataset("iris", Dataset.BlockSizeAuto, irisExtractor); - - ITrainer trainer = sigma.CreateGhostTrainer("test"); - - trainer.Network = new Network(); - trainer.Network.Architecture = InputLayer.Construct(4) - + FullyConnectedLayer.Construct(12) - + FullyConnectedLayer.Construct(10) - + FullyConnectedLayer.Construct(3) - + OutputLayer.Construct(3) - + SoftMaxCrossEntropyCostLayer.Construct(); - trainer.TrainingDataIterator = new MinibatchIterator(4, dataset); - trainer.AddNamedDataIterator("validation", new UndividedIterator(dataset)); - trainer.Optimiser = new AdadeltaOptimiser(decayRate: 0.9); - trainer.Operator = new CpuSinglethreadedOperator(new DebugHandler(new CpuFloat32Handler())); + public static class Program + { + public static MinibatchIterator TrainingIterator; - trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.2)); - trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.1, mean: 0.0)); - - trainer.AddGlobalHook(new StopTrainingHook(atEpoch: 100)); - //trainer.AddLocalHook(new EarlyStopperHook("optimiser.cost_total", 20, target: ExtremaTarget.Min)); - trainer.AddHook(new ValueReporterHook("optimiser.cost_total", TimeStep.Every(1, TimeScale.Epoch))); - trainer.AddHook(new ValidationAccuracyReporter("validation", TimeStep.Every(1, TimeScale.Epoch), tops: 1)); - trainer.AddHook(new 
RunningTimeReporter(TimeStep.Every(1, TimeScale.Epoch))); + private static void Main(string[] args) + { + SigmaEnvironment.EnableLogging(xml: true); + SigmaEnvironment.Globals["web_proxy"] = WebUtils.GetProxyFromFileOrDefault(".customproxy"); - //trainer.AddGlobalHook(new CurrentEpochIterationReporter(TimeStep.Every(1, TimeScale.Epoch))); - - //Serialisation.WriteBinaryFile(trainer, "trainer.sgtrainer"); - //trainer = Serialisation.ReadBinaryFile("trainer.sgtrainer"); - - sigma.AddTrainer(trainer); + SampleXOR(); - //trainer.Operator.InvokeCommand(new TestCommand(() => { throw new NotImplementedException(); }, "optimiser.learning_rate")); - trainer.Operator.InvokeCommand(new SetValueCommand("optimiser.learning_rate", 0.02d, () => {/* finished */})); - - sigma.Run(); - } - - [Serializable] - private class TestCommand : BaseCommand - { - private readonly ILog _log = LogManager.GetLogger(typeof(TestCommand)); - public TestCommand(Action onFinish = null, params string[] requiredRegistryEntries) : base(onFinish, requiredRegistryEntries) - { - _log.Info("Test command created"); - } + Console.WriteLine("Program ended, waiting for termination, press any key..."); + Console.ReadKey(); + } - /// - /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied. - /// - /// The registry containing the required values for this hook's execution. - /// A helper resolver for complex registry entries (automatically cached). 
- public override void SubInvoke(IRegistry registry, IRegistryResolver resolver) - { - _log.Info("Test command invoked"); - //resolver.ResolveSet("optimiser.learning_rate", 10); - } - } + private static void SampleXOR() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("xor"); + sigma.Prepare(); - private static void SampleTrainerOperatorWorkerMnist() - { - SigmaEnvironment sigma = SigmaEnvironment.Create("trainer_test"); + RawDataset dataset = new RawDataset("xor"); + dataset.AddRecords("inputs", new[] { 0, 0 }, new[] { 0, 1 }, new[] { 1, 0 }, new[] { 1, 1 }); + dataset.AddRecords("targets", new[] { 0 }, new[] { 1 }, new[] { 1 }, new[] { 0 }); - sigma.Prepare(); + ITrainer trainer = sigma.CreateTrainer("xor-trainer"); - ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")))); - IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); + trainer.Network = new Network(); + trainer.Network.Architecture = InputLayer.Construct(2) + FullyConnectedLayer.Construct(2) + FullyConnectedLayer.Construct(1) + OutputLayer.Construct(1) + SquaredDifferenceCostLayer.Construct(); + trainer.TrainingDataIterator = new MinibatchIterator(1, dataset); + trainer.Operator = new CpuSinglethreadedOperator(); + trainer.Optimiser = new GradientDescentOptimiser(learningRate: 0.01); - ByteRecordReader mnistTargetReader = new ByteRecordReader(headerLengthBytes: 8, recordSizeBytes: 1, source: new CompressedSource(new MultiSource(new FileSource("train-labels-idx1-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")))); - IRecordExtractor mnistTargetExtractor = mnistTargetReader.Extractor("targets", new[] { 0L }, new[] { 1L 
}).Preprocess(new OneHotPreprocessor(minValue: 0, maxValue: 9)); + trainer.AddInitialiser("*.*", new GaussianInitialiser(standardDeviation: 0.1)); - IDataset dataset = new Dataset("mnist-training", Dataset.BlockSizeAuto, mnistImageExtractor, mnistTargetExtractor); - ITrainer trainer = sigma.CreateTrainer("test"); + trainer.AddLocalHook(new ValueReporterHook("optimiser.cost_total", TimeStep.Every(1, TimeScale.Epoch), reportEpochIteration: true)); - trainer.Network = new Network(); - trainer.Network.Architecture = InputLayer.Construct(28, 28) - + FullyConnectedLayer.Construct(28 * 28) - + FullyConnectedLayer.Construct(10) - + OutputLayer.Construct(10) - + SoftMaxCrossEntropyCostLayer.Construct(); - trainer.TrainingDataIterator = new MinibatchIterator(20, dataset); - trainer.Optimiser = new AdagradOptimiser(baseLearningRate: 0.02); - trainer.Operator = new CpuSinglethreadedOperator(); + sigma.Run(); + } - trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.25f)); - trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.01f, mean: 0.03f)); + private static void SampleIris() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("trainer_test"); + sigma.SetRandomSeed(137); - trainer.AddGlobalHook(new CurrentEpochIterationReporter(TimeStep.Every(5, TimeScale.Iteration))); - trainer.AddLocalHook(new ValueReporterHook("optimiser.cost_total", TimeStep.Every(5, TimeScale.Iteration))); + sigma.Prepare(); - sigma.Run(); - } + var irisReader = new CsvRecordReader(new MultiSource(new FileSource("iris.data"), new UrlSource("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"))); + IRecordExtractor irisExtractor = irisReader.Extractor("inputs", new[] { 0, 3 }, "targets", 4).AddValueMapping(4, "Iris-setosa", "Iris-versicolor", "Iris-virginica") + .Preprocess(new OneHotPreprocessor("targets", minValue: 0, maxValue: 2)) + .Preprocess(new AdaptiveNormalisingPreprocessor(minOutputValue: 0.0, maxOutputValue: 1.0)) 
+ .Preprocess(new ShufflePreprocessor()); - private static void SampleCachedFastIteration() - { - SigmaEnvironment sigma = SigmaEnvironment.Create("test"); - - IDataSource dataSource = new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))); - - ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: dataSource); - IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); - - IDataset dataset = new Dataset("mnist-training", Dataset.BlockSizeAuto, mnistImageExtractor); - IDataset[] slices = dataset.SplitRecordwise(0.8, 0.2); - IDataset trainingData = slices[0]; - - Stopwatch stopwatch = Stopwatch.StartNew(); - - IDataIterator iterator = new MinibatchIterator(10, trainingData); - foreach (var block in iterator.Yield(new CpuFloat32Handler(), sigma)) - { - //PrintFormattedBlock(block, PrintUtils.AsciiGreyscalePalette); - } + IDataset dataset = new ExtractedDataset("iris", ExtractedDataset.BlockSizeAuto, false, irisExtractor); - Console.Write("\nFirst iteration took " + stopwatch.Elapsed + "\n+=+ Iterating over dataset again +=+ Dramatic pause..."); + ITrainer trainer = sigma.CreateGhostTrainer("test"); - ArrayUtils.Range(1, 10).ToList().ForEach(i => - { - Thread.Sleep(500); - Console.Write("."); - }); + trainer.Network = new Network(); + trainer.Network.Architecture = InputLayer.Construct(4) + + FullyConnectedLayer.Construct(4) + + 2 * FullyConnectedLayer.Construct(24) + + FullyConnectedLayer.Construct(3) + + OutputLayer.Construct(3) + + SoftMaxCrossEntropyCostLayer.Construct(); + //trainer.Network = Serialisation.ReadBinaryFileIfExists("iris.sgnet", trainer.Network); - stopwatch.Restart(); + trainer.TrainingDataIterator = new MinibatchIterator(10, dataset); + trainer.AddNamedDataIterator("validation", 
new UndividedIterator(dataset)); + trainer.Optimiser = new AdagradOptimiser(baseLearningRate: 0.01); + trainer.Operator = new CpuSinglethreadedOperator(); - foreach (var block in iterator.Yield(new CpuFloat32Handler(), sigma)) - { - //PrintFormattedBlock(block, PrintUtils.AsciiGreyscalePalette); - } + trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.3)); + trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.1)); - Console.WriteLine("Second iteration took " + stopwatch.Elapsed); - } + //trainer.AddGlobalHook(new StopTrainingHook(atEpoch: 100)); + //trainer.AddLocalHook(new EarlyStopperHook("optimiser.cost_total", 20, target: ExtremaTarget.Min)); - private static void SampleDotProduct() - { - IComputationHandler handler = new CpuFloat32Handler(); + trainer.AddLocalHook(new ValueReporterHook("optimiser.cost_total", TimeStep.Every(1, TimeScale.Epoch), reportEpochIteration: true)); + //.On(new ExtremaCriteria("optimiser.cost_total", ExtremaTarget.Min))); + trainer.AddLocalHook(new DiskSaviorHook("network.self", Namers.Dynamic("iris_epoch{0}.sgnet", "epoch"), verbose: true) + .On(new ExtremaCriteria("optimiser.cost_total", ExtremaTarget.Min))); - INDArray a = handler.NDArray(ArrayUtils.Range(1, 6), 3, 2); - INDArray b = handler.NDArray(ArrayUtils.Range(1, 6), 2, 3); - - Console.WriteLine("a = " + ArrayUtils.ToString(a, (ADNDArray.ToStringElement) null, 0, true)); - Console.WriteLine("b = " + ArrayUtils.ToString(b, (ADNDArray.ToStringElement) null, 0, true)); + trainer.AddHook(new ValidationAccuracyReporter("validation", TimeStep.Every(1, TimeScale.Epoch), tops: 1)); + trainer.AddHook(new StopTrainingHook(new ThresholdCriteria("shared.validation_accuracy_top1", ComparisonTarget.GreaterThanEquals, 0.95))); - INDArray c = handler.Dot(a, b); + Serialisation.WriteBinaryFile(trainer, "trainer.sgtrainer"); + trainer = Serialisation.ReadBinaryFile("trainer.sgtrainer"); - Console.WriteLine("c = " + ArrayUtils.ToString(c, 
(ADNDArray.ToStringElement) null, 0, true)); - } + sigma.AddTrainer(trainer); - private static void SampleNetworkMerging() - { - SigmaEnvironment sigma = SigmaEnvironment.Create("merge_test"); + sigma.Run(); + } - ITrainer[] trainers = new ITrainer[3]; - int[] constantValues = { 2, 10, 70 }; + private static void SampleMnist() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("trainer_test"); - //INetworkMerger merger = new WeightedNetworkMerger(10d, 10d, 1d); - INetworkMerger merger = new AverageNetworkMerger(); - IComputationHandler handler = new CpuFloat32Handler(); + sigma.Prepare(); - for (int i = 0; i < trainers.Length; i++) - { - trainers[i] = sigma.CreateTrainer($"MergeTrainer{i}"); - trainers[i].Network = new Network($"{i}"); - trainers[i].Network.Architecture = InputLayer.Construct(2, 2) + ElementwiseLayer.Construct(2 * 2) + OutputLayer.Construct(2); + ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")))); + IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); - trainers[i].AddInitialiser("*.weights", new ConstantValueInitialiser(constantValues[i])); + ByteRecordReader mnistTargetReader = new ByteRecordReader(headerLengthBytes: 8, recordSizeBytes: 1, source: new CompressedSource(new MultiSource(new FileSource("train-labels-idx1-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")))); + IRecordExtractor mnistTargetExtractor = mnistTargetReader.Extractor("targets", new[] { 0L }, new[] { 1L }).Preprocess(new OneHotPreprocessor(minValue: 0, maxValue: 9)); - trainers[i].Operator = new CpuMultithreadedOperator(5); - trainers[i].Initialise(handler); - } + IDataset dataset = new 
ExtractedDataset("mnist", ExtractedDataset.BlockSizeAuto, false, mnistImageExtractor, mnistTargetExtractor); + ITrainer trainer = sigma.CreateTrainer("test"); - foreach (ITrainer trainer in trainers) - { - Console.WriteLine(trainer.Network.Registry); - } + trainer.Network = new Network(); + trainer.Network.Architecture = InputLayer.Construct(28, 28) + + FullyConnectedLayer.Construct(28 * 28) + + FullyConnectedLayer.Construct(10) + + OutputLayer.Construct(10) + + SoftMaxCrossEntropyCostLayer.Construct(); + trainer.Network = Serialisation.ReadBinaryFileIfExists("mnist.sgnet", trainer.Network); + trainer.TrainingDataIterator = new MinibatchIterator(100, dataset); + trainer.AddNamedDataIterator("validation", new UndividedIterator(dataset)); + trainer.Optimiser = new AdagradOptimiser(baseLearningRate: 0.01); + trainer.Operator = new CpuSinglethreadedOperator(); - merger.AddMergeEntry("layers.*.weights"); - merger.Merge(trainers[1].Network, trainers[2].Network, handler); + trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.1)); + trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.05)); - Console.WriteLine("*******************"); - foreach (ITrainer trainer in trainers) - { - Console.WriteLine(trainer.Network.Registry); - } - } + trainer.AddLocalHook(new ValueReporterHook("optimiser.cost_total", TimeStep.Every(1, TimeScale.Epoch), reportEpochIteration: true)); + //trainer.AddLocalHook(new ValueReporterHook("optimiser.cost_total", TimeStep.Every(1, TimeScale.Iteration), reportEpochIteration: true) + // .On(new ExtremaCriteria("optimiser.cost_total", ExtremaTarget.Min))); + trainer.AddLocalHook(new DiskSaviorHook("network.self", "mnist.sgnet", verbose: true) + .On(new ExtremaCriteria("optimiser.cost_total", ExtremaTarget.Min))); - private static void SampleNetworkArchitecture() - { - SigmaEnvironment sigma = SigmaEnvironment.Create("test"); + var validationTimeStep = TimeStep.Every(1, TimeScale.Epoch); - 
IComputationHandler handler = new CpuFloat32Handler(); - ITrainer trainer = sigma.CreateTrainer("test_trainer"); - trainer.Network = new Network(); - trainer.Network.Architecture = InputLayer.Construct(2, 2) + - ElementwiseLayer.Construct(2 * 2) + - FullyConnectedLayer.Construct(2) + - 2 * (FullyConnectedLayer.Construct(4) + FullyConnectedLayer.Construct(2)) + - OutputLayer.Construct(2); - trainer.Network = (INetwork) trainer.Network.DeepCopy(); + trainer.AddHook(new ValidationAccuracyReporter("validation", validationTimeStep, tops: 1)); + trainer.AddHook(new StopTrainingHook(new ThresholdCriteria("shared.validation_accuracy_top1", ComparisonTarget.GreaterThanEquals, 0.5), validationTimeStep)); + trainer.AddHook(new StopTrainingHook(atEpoch: 500)); - trainer.Operator = new CpuMultithreadedOperator(10); + sigma.Run(); + } - trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.1f)); - trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.01f, mean: 0.03f)); - trainer.Initialise(handler); + private static void SampleCachedFastIteration() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("test"); - trainer.Network = (INetwork) trainer.Network.DeepCopy(); + IDataSource dataSource = new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))); - Console.WriteLine(trainer.Network.Registry); + ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: dataSource); + IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); - IRegistryResolver resolver = new RegistryResolver(trainer.Network.Registry); + IDataset dataset = new ExtractedDataset("mnist-training", ExtractedDataset.BlockSizeAuto, mnistImageExtractor); + IDataset[] slices = 
dataset.SplitRecordwise(0.8, 0.2); + IDataset trainingData = slices[0]; - Console.WriteLine("==============="); - object[] weights = resolver.ResolveGet("layers.*.weights"); - Console.WriteLine(string.Join("\n", weights)); - Console.WriteLine("==============="); + Stopwatch stopwatch = Stopwatch.StartNew(); + IDataIterator iterator = new MinibatchIterator(10, trainingData); + foreach (var block in iterator.Yield(new CpuFloat32Handler(), sigma)) + { + //PrintFormattedBlock(block, PrintUtils.AsciiGreyscalePalette); + } + Console.Write("\nFirst iteration took " + stopwatch.Elapsed + "\n+=+ Iterating over dataset again +=+ Dramatic pause..."); - //foreach (ILayerBuffer buffer in trainer.Network.YieldLayerBuffersOrdered()) - //{ - // Console.WriteLine(buffer.Layer.Name + ": "); + ArrayUtils.Range(1, 10).ToList().ForEach(i => + { + Thread.Sleep(500); + Console.Write("."); + }); - // Console.WriteLine("inputs:"); - // foreach (string input in buffer.Inputs.Keys) - // { - // Console.WriteLine($"\t{input}: {buffer.Inputs[input].GetHashCode()}"); - // } + stopwatch.Restart(); - // Console.WriteLine("outputs:"); - // foreach (string output in buffer.Outputs.Keys) - // { - // Console.WriteLine($"\t{output}: {buffer.Outputs[output].GetHashCode()}"); - // } - //} - } + foreach (var block in iterator.Yield(new CpuFloat32Handler(), sigma)) + { + //PrintFormattedBlock(block, PrintUtils.AsciiGreyscalePalette); + } - private static void SampleAutomaticDifferentiation() - { - IComputationHandler handler = new CpuFloat32Handler(); + Console.WriteLine("Second iteration took " + stopwatch.Elapsed); + } - uint traceTag = handler.BeginTrace(); + private static void SampleDotProduct() + { + IComputationHandler handler = new CpuFloat32Handler(); - INDArray array = handler.NDArray(ArrayUtils.Range(1, 6), 2, 3); - INumber a = handler.Number(-1.0f), b = handler.Number(3.0f); + INDArray a = handler.NDArray(ArrayUtils.Range(1, 6), 3, 2); + INDArray b = handler.NDArray(ArrayUtils.Range(1, 6), 2, 
3); - INumber c = handler.Trace(handler.Add(a, b), traceTag); - INumber d = handler.Multiply(c, 2); - INumber e = handler.Add(d, handler.Add(c, 3)); - INumber f = handler.SquareRoot(e); + Console.WriteLine("a = " + ArrayUtils.ToString(a, (ADNDArray.ToStringElement)null, 0, true)); + Console.WriteLine("b = " + ArrayUtils.ToString(b, (ADNDArray.ToStringElement)null, 0, true)); - array = handler.Multiply(array, f); + INDArray c = handler.Dot(a, b); - INumber cost = handler.Divide(handler.Sum(array), array.Length); + Console.WriteLine("c = " + ArrayUtils.ToString(c, (ADNDArray.ToStringElement)null, 0, true)); + } - Console.WriteLine("cost: " + cost); + private static void SampleNetworkMerging() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("merge_test"); - handler.ComputeDerivativesTo(cost); + ITrainer[] trainers = new ITrainer[3]; + int[] constantValues = { 2, 10, 70 }; - Console.WriteLine(array); - Console.WriteLine("f: " + handler.GetDerivative(f)); - Console.WriteLine("e: " + handler.GetDerivative(e)); - Console.WriteLine("d: " + handler.GetDerivative(d)); - Console.WriteLine("c: " + handler.GetDerivative(c)); - Console.WriteLine("a: " + handler.GetDerivative(array)); + //INetworkMerger merger = new WeightedNetworkMerger(10d, 10d, 1d); + INetworkMerger merger = new AverageNetworkMerger(); + IComputationHandler handler = new CpuFloat32Handler(); - handler.ComputeDerivativesTo(f); + for (int i = 0; i < trainers.Length; i++) + { + trainers[i] = sigma.CreateTrainer($"MergeTrainer{i}"); + trainers[i].Network = new Network($"{i}"); + trainers[i].Network.Architecture = InputLayer.Construct(2, 2) + ElementwiseLayer.Construct(2 * 2) + OutputLayer.Construct(2); - Console.WriteLine("f: " + handler.GetDerivative(f)); - Console.WriteLine("e: " + handler.GetDerivative(e)); - Console.WriteLine("d: " + handler.GetDerivative(d)); - Console.WriteLine("c: " + handler.GetDerivative(c)); - Console.WriteLine("a: " + handler.GetDerivative(array)); - } + 
trainers[i].AddInitialiser("*.weights", new ConstantValueInitialiser(constantValues[i])); - private static void SampleLoadExtractIterate() - { - SigmaEnvironment sigma = SigmaEnvironment.Create("test"); + trainers[i].Operator = new CpuMultithreadedOperator(5); + trainers[i].Initialise(handler); + } - sigma.Prepare(); + foreach (ITrainer trainer in trainers) + { + Console.WriteLine(trainer.Network.Registry); + } - //var irisReader = new CsvRecordReader(new MultiSource(new FileSource("iris.data"), new UrlSource("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"))); - //IRecordExtractor irisExtractor = irisReader.Extractor("inputs2", new[] { 0, 3 }, "targets2", 4).AddValueMapping(4, "Iris-setosa", "Iris-versicolor", "Iris-virginica"); - //irisExtractor = irisExtractor.Preprocess(new OneHotPreprocessor(sectionName: "targets2", minValue: 0, maxValue: 2), new NormalisingPreprocessor(sectionNames: "inputs2", minInputValue: 0, maxInputValue: 6)); + merger.AddMergeEntry("layers.*.weights"); + merger.Merge(trainers[1].Network, trainers[2].Network, handler); - ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")))); - IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); + Console.WriteLine("*******************"); + foreach (ITrainer trainer in trainers) + { + Console.WriteLine(trainer.Network.Registry); + } + } - ByteRecordReader mnistTargetReader = new ByteRecordReader(headerLengthBytes: 8, recordSizeBytes: 1, source: new CompressedSource(new MultiSource(new FileSource("train-labels-idx1-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")))); - IRecordExtractor mnistTargetExtractor = 
mnistTargetReader.Extractor("targets", new[] { 0L }, new[] { 1L }).Preprocess(new OneHotPreprocessor(minValue: 0, maxValue: 9)); + private static void SampleNetworkArchitecture() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("test"); - IComputationHandler handler = new CpuFloat32Handler(); + IComputationHandler handler = new CpuFloat32Handler(); + ITrainer trainer = sigma.CreateTrainer("test_trainer"); + trainer.Network = new Network(); + trainer.Network.Architecture = InputLayer.Construct(2, 2) + + ElementwiseLayer.Construct(2 * 2) + + FullyConnectedLayer.Construct(2) + + 2 * (FullyConnectedLayer.Construct(4) + FullyConnectedLayer.Construct(2)) + + OutputLayer.Construct(2); + trainer.Network = (INetwork)trainer.Network.DeepCopy(); - Dataset dataset = new Dataset("mnist-training", Dataset.BlockSizeAuto, mnistImageExtractor, mnistTargetExtractor); - IDataset[] slices = dataset.SplitRecordwise(0.8, 0.2); - IDataset trainingData = slices[0]; - IDataset validationData = slices[1]; + trainer.Operator = new CpuMultithreadedOperator(10); - MinibatchIterator trainingIterator = new MinibatchIterator(1, trainingData); - MinibatchIterator validationIterator = new MinibatchIterator(1, validationData); + trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.1f)); + trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.01f, mean: 0.03f)); + trainer.Initialise(handler); - while (true) - { - foreach (var block in trainingIterator.Yield(handler, sigma)) - { - Thread.Sleep(100); + trainer.Network = (INetwork)trainer.Network.DeepCopy(); - PrintFormattedBlock(block, PrintUtils.AsciiGreyscalePalette); + Console.WriteLine(trainer.Network.Registry); - Thread.Sleep(1000); - } - } + IRegistryResolver resolver = new RegistryResolver(trainer.Network.Registry); - //Random random = new Random(); - //INDArray array = new ADNDArray(3, 1, 2, 2); + Console.WriteLine("==============="); + object[] weights = 
resolver.ResolveGet("layers.*.weights"); + Console.WriteLine(string.Join("\n", weights)); + Console.WriteLine("==============="); - //new GaussianInitialiser(0.05, 0.05).Initialise(array, Handler, random); - //Console.WriteLine(array); - //new ConstantValueInitialiser(1).Initialise(array, Handler, random); + //foreach (ILayerBuffer buffer in trainer.Network.YieldLayerBuffersOrdered()) + //{ + // Console.WriteLine(buffer.Layer.Name + ": "); - //Console.WriteLine(array); + // Console.WriteLine("inputs:"); + // foreach (string input in buffer.Inputs.Keys) + // { + // Console.WriteLine($"\t{input}: {buffer.Inputs[input].GetHashCode()}"); + // } - //dataset.InvalidateAndClearCaches(); - } + // Console.WriteLine("outputs:"); + // foreach (string output in buffer.Outputs.Keys) + // { + // Console.WriteLine($"\t{output}: {buffer.Outputs[output].GetHashCode()}"); + // } + //} + } - private static void PrintFormattedBlock(IDictionary block, char[] palette) - { - foreach (string name in block.Keys) - { - string blockString = name == "inputs" - ? 
ArrayUtils.ToString(block[name], e => palette[(int) (e * (palette.Length - 1))].ToString(), maxDimensionNewLine: 0, printSeperator: false) - : block[name].ToString(); + private static void SampleAutomaticDifferentiation() + { + IComputationHandler handler = new CpuFloat32Handler(); - Console.WriteLine($"[{name}]=\n" + blockString); - } - } - } + uint traceTag = handler.BeginTrace(); + + INDArray array = handler.NDArray(ArrayUtils.Range(1, 6), 2, 3); + INumber a = handler.Number(-1.0f), b = handler.Number(3.0f); + + INumber c = handler.Trace(handler.Add(a, b), traceTag); + INumber d = handler.Multiply(c, 2); + INumber e = handler.Add(d, handler.Add(c, 3)); + INumber f = handler.SquareRoot(e); + + array = handler.Multiply(array, f); + + INumber cost = handler.Divide(handler.Sum(array), array.Length); + + Console.WriteLine("cost: " + cost); + + handler.ComputeDerivativesTo(cost); + + Console.WriteLine(array); + Console.WriteLine("f: " + handler.GetDerivative(f)); + Console.WriteLine("e: " + handler.GetDerivative(e)); + Console.WriteLine("d: " + handler.GetDerivative(d)); + Console.WriteLine("c: " + handler.GetDerivative(c)); + Console.WriteLine("a: " + handler.GetDerivative(array)); + + handler.ComputeDerivativesTo(f); + + Console.WriteLine("f: " + handler.GetDerivative(f)); + Console.WriteLine("e: " + handler.GetDerivative(e)); + Console.WriteLine("d: " + handler.GetDerivative(d)); + Console.WriteLine("c: " + handler.GetDerivative(c)); + Console.WriteLine("a: " + handler.GetDerivative(array)); + } + + private static void SampleLoadExtractIterate() + { + SigmaEnvironment sigma = SigmaEnvironment.Create("test"); + + sigma.Prepare(); + + //var irisReader = new CsvRecordReader(new MultiSource(new FileSource("iris.data"), new UrlSource("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"))); + //IRecordExtractor irisExtractor = irisReader.Extractor("inputs2", new[] { 0, 3 }, "targets2", 4).AddValueMapping(4, "Iris-setosa", "Iris-versicolor", 
"Iris-virginica"); + //irisExtractor = irisExtractor.Preprocess(new OneHotPreprocessor(sectionName: "targets2", minValue: 0, maxValue: 2), new NormalisingPreprocessor(sectionNames: "inputs2", minInputValue: 0, maxInputValue: 6)); + + ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")))); + IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); + + ByteRecordReader mnistTargetReader = new ByteRecordReader(headerLengthBytes: 8, recordSizeBytes: 1, source: new CompressedSource(new MultiSource(new FileSource("train-labels-idx1-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")))); + IRecordExtractor mnistTargetExtractor = mnistTargetReader.Extractor("targets", new[] { 0L }, new[] { 1L }).Preprocess(new OneHotPreprocessor(minValue: 0, maxValue: 9)); + + IComputationHandler handler = new CpuFloat32Handler(); + + ExtractedDataset dataset = new ExtractedDataset("mnist-training", ExtractedDataset.BlockSizeAuto, mnistImageExtractor, mnistTargetExtractor); + IDataset[] slices = dataset.SplitRecordwise(0.8, 0.2); + IDataset trainingData = slices[0]; + IDataset validationData = slices[1]; + + MinibatchIterator trainingIterator = new MinibatchIterator(1, trainingData); + MinibatchIterator validationIterator = new MinibatchIterator(1, validationData); + + while (true) + { + foreach (var block in trainingIterator.Yield(handler, sigma)) + { + Thread.Sleep(100); + + PrintFormattedBlock(block, PrintUtils.AsciiGreyscalePalette); + + Thread.Sleep(1000); + } + } + + //Random random = new Random(); + //INDArray array = new ADNDArray(3, 1, 2, 2); + + //new GaussianInitialiser(0.05, 0.05).Initialise(array, Handler, random); + + 
//Console.WriteLine(array); + + //new ConstantValueInitialiser(1).Initialise(array, Handler, random); + + //Console.WriteLine(array); + + //dataset.InvalidateAndClearCaches(); + } + + private static void PrintFormattedBlock(IDictionary block, char[] palette) + { + foreach (string name in block.Keys) + { + string blockString = name == "inputs" + ? ArrayUtils.ToString(block[name], e => palette[(int)(e * (palette.Length - 1))].ToString(), maxDimensionNewLine: 0, printSeperator: false) + : block[name].ToString(); + + Console.WriteLine($"[{name}]=\n" + blockString); + } + } + } } \ No newline at end of file diff --git a/Sigma.Tests.Internals.WPF/Program.cs b/Sigma.Tests.Internals.WPF/Program.cs index 55495b86..7ea19b00 100644 --- a/Sigma.Tests.Internals.WPF/Program.cs +++ b/Sigma.Tests.Internals.WPF/Program.cs @@ -10,105 +10,229 @@ using Sigma.Core.Layers.Cost; using Sigma.Core.Layers.External; using Sigma.Core.Layers.Feedforward; +using Sigma.Core.MathAbstract; using Sigma.Core.Monitors.WPF; +using Sigma.Core.Monitors.WPF.Model.UI.Resources; +using Sigma.Core.Monitors.WPF.Model.UI.StatusBar; using Sigma.Core.Monitors.WPF.Panels.Charts; using Sigma.Core.Monitors.WPF.Panels.Controls; +using Sigma.Core.Monitors.WPF.Panels.Parameterisation; using Sigma.Core.Monitors.WPF.Utils; +using Sigma.Core.Monitors.WPF.View.Parameterisation; using Sigma.Core.Training; +using Sigma.Core.Training.Hooks.Processors; using Sigma.Core.Training.Hooks.Reporters; using Sigma.Core.Training.Initialisers; using Sigma.Core.Training.Operators.Backends.NativeCpu; +using Sigma.Core.Training.Optimisers.Gradient; using Sigma.Core.Training.Optimisers.Gradient.Memory; using Sigma.Core.Utils; +using System; +using System.Threading; +using Sigma.Core.Data.Preprocessors.Adaptive; +using Sigma.Core.Monitors.WPF.Model.UI.Windows; +using Sigma.Core.Training.Hooks; +using Sigma.Core.Training.Hooks.Saviors; +using Sigma.Core.Training.Hooks.Stoppers; +using Sigma.Core.Training.Modifiers; namespace 
Sigma.Tests.Internals.WPF { - internal class Program - { - private const bool UI = true; - - private static void Main() - { - SigmaEnvironment.EnableLogging(); - SigmaEnvironment sigma = SigmaEnvironment.Create("Sigma-MNIST"); - - // create a new mnist trainer - ITrainer trainer = CreateMnistTrainer(sigma); - - // for the UI we have to activate more features - if (UI) - { - // create and attach a new UI framework - WPFMonitor gui = sigma.AddMonitor(new WPFMonitor("MNIST")); - - // create a tab - gui.AddTabs("Overview"); - - // access the window inside the ui thread - gui.WindowDispatcher(window => - { - // enable initialisation - window.IsInitializing = true; - - // add a panel that controls the learning process - window.TabControl["Overview"].AddCumulativePanel(new ControlPanel("Control", trainer)); - - // create an accuracy cost that updates every iteration - var cost = new TrainerChartPanel, double>("Cost", trainer, "optimiser.cost_total", TimeStep.Every(1, TimeScale.Iteration)); - // improve the chart performance - cost.Fast(); - - // add the newly created panel - window.TabControl["Overview"].AddCumulativePanel(cost); - - // finish initialisation - window.IsInitializing = false; - }); - - // the operators should not run instantly but when the user clicks play - sigma.StartOperatorsOnRun = false; - } - - sigma.Prepare(); - - sigma.Run(); - } - - /// - /// Create a MNIST trainer (writing recognition) will be added to an environemnt. - /// - /// The sigma environemnt this trainer will be assigned to. - /// The newly created trainer. 
- private static ITrainer CreateMnistTrainer(SigmaEnvironment sigma) - { - ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")))); - IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); - - ByteRecordReader mnistTargetReader = new ByteRecordReader(headerLengthBytes: 8, recordSizeBytes: 1, source: new CompressedSource(new MultiSource(new FileSource("train-labels-idx1-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")))); - IRecordExtractor mnistTargetExtractor = mnistTargetReader.Extractor("targets", new[] { 0L }, new[] { 1L }).Preprocess(new OneHotPreprocessor(minValue: 0, maxValue: 9)); - - IDataset dataset = new Dataset("mnist-training", Dataset.BlockSizeAuto, mnistImageExtractor, mnistTargetExtractor); - ITrainer trainer = sigma.CreateTrainer("test"); - - trainer.Network = new Network - { - Architecture = InputLayer.Construct(28, 28) - + 2 * FullyConnectedLayer.Construct(28 * 28) - + FullyConnectedLayer.Construct(10) - + OutputLayer.Construct(10) - + SoftMaxCrossEntropyCostLayer.Construct() - }; - - trainer.TrainingDataIterator = new MinibatchIterator(8, dataset); - trainer.Optimiser = new AdagradOptimiser(baseLearningRate: 0.02); - trainer.Operator = new CpuSinglethreadedOperator(); - - trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.05f)); - trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.01f, mean: 0.03f)); - - trainer.AddGlobalHook(new CurrentEpochIterationReporter(TimeStep.Every(1, TimeScale.Iteration))); - - return trainer; - } - } + + internal class Program + { + private const bool SampleMnist = true; + + private static void 
Main() + { + SigmaEnvironment.EnableLogging(); + SigmaEnvironment sigma = SigmaEnvironment.Create("sigma_demo"); + + // create a new mnist trainer + string name = SampleMnist ? "MNIST" : "IRIS"; + ITrainer trainer = SampleMnist ? CreateMnistTrainer(sigma) : CreateIrisTrainer(sigma); + + trainer.AddLocalHook(new MetricProcessorHook("network.layers.*.weights", (a, h) => h.Divide(h.Sum(a), a.Length), "shared.network_weights_average")); + trainer.AddLocalHook(new MetricProcessorHook("network.layers.*.weights", (a, h) => h.StandardDeviation(a), "shared.network_weights_stddev")); + trainer.AddLocalHook(new MetricProcessorHook("network.layers.*.biases", (a, h) => h.Divide(h.Sum(a), a.Length), "shared.network_biases_average")); + trainer.AddLocalHook(new MetricProcessorHook("network.layers.*.biases", (a, h) => h.StandardDeviation(a), "shared.network_biases_stddev")); + trainer.AddLocalHook(new MetricProcessorHook("optimiser.updates", (a, h) => h.Divide(h.Sum(a), a.Length), "shared.optimiser_updates_average")); + trainer.AddLocalHook(new MetricProcessorHook("optimiser.updates", (a, h) => h.StandardDeviation(a), "shared.optimiser_updates_stddev")); + + // create and attach a new UI framework + WPFMonitor gui = sigma.AddMonitor(new WPFMonitor(name, SampleMnist ? 
"de-DE" : "en-EN")); + gui.ColourManager.Dark = !SampleMnist; + + StatusBarLegendInfo iris = new StatusBarLegendInfo(name, MaterialColour.Blue); + StatusBarLegendInfo general = new StatusBarLegendInfo("General", MaterialColour.Grey); + gui.AddLegend(iris); + gui.AddLegend(general); + + // create a tab + gui.AddTabs("Overview", "Metrics", "Validation"); + + // access the window inside the ui thread + gui.WindowDispatcher(window => + { + // enable initialisation + window.IsInitializing = true; + + window.TabControl["Metrics"].GridSize = new GridSize(2, 4); + window.TabControl["Validation"].GridSize = new GridSize(1, 2); + + window.TabControl["Overview"].GridSize.Rows -= 1; + window.TabControl["Overview"].GridSize.Columns -= 1; + + // add a panel that controls the learning process + window.TabControl["Overview"].AddCumulativePanel(new ControlPanel("Control", trainer), legend: iris); + + ITimeStep reportTimeStep = SampleMnist ? TimeStep.Every(1, TimeScale.Iteration) : TimeStep.Every(10, TimeScale.Epoch); + + var cost1 = new TrainerChartPanel, double>("Cost / Epoch", trainer, "optimiser.cost_total", reportTimeStep); + cost1.Fast(); + //var cost2 = new TrainerChartPanel, double>("Cost / Epoch", trainer, "optimiser.cost_total", reportTimeStep); + //cost2.Fast(); + + var weightAverage = new TrainerChartPanel, double>("Mean of Weights / Epoch", trainer, "shared.network_weights_average", reportTimeStep, averageMode: true); + weightAverage.Fast(); + + var weightStddev = new TrainerChartPanel, double>("Standard Deviation of Weights / Epoch", trainer, "shared.network_weights_stddev", reportTimeStep, averageMode: true); + weightStddev.Fast(); + + var biasesAverage = new TrainerChartPanel, double>("Mean of Biases / Epoch", trainer, "shared.network_biases_average", reportTimeStep, averageMode: true); + biasesAverage.Fast(); + + var biasesStddev = new TrainerChartPanel, double>("Standard Deviation of Biases / Epoch", trainer, "shared.network_biases_stddev", reportTimeStep, 
averageMode: true); + biasesStddev.Fast(); + + var updateAverage = new TrainerChartPanel, double>("Mean of Parameter Updates / Epoch", trainer, "shared.optimiser_updates_average", reportTimeStep, averageMode: true); + updateAverage.Fast(); + + var updateStddev = new TrainerChartPanel, double>("Standard Deviation of Parameter Updates / Epoch", trainer, "shared.optimiser_updates_stddev", reportTimeStep, averageMode: true); + updateStddev.Fast(); + + var accuracy1 = new AccuracyPanel("Validation Accuracy", trainer, SampleMnist ? TimeStep.Every(1, TimeScale.Epoch) : reportTimeStep, null, 1, 2); + accuracy1.Fast(); + var accuracy2 = new AccuracyPanel("Validation Accuracy", trainer, SampleMnist ? TimeStep.Every(1, TimeScale.Epoch) : reportTimeStep, null, 1, 2); + accuracy2.Fast(); + + IRegistry regTest = new Registry(); + regTest.Add("test", DateTime.Now); + + var parameter = new ParameterPanel("Parameters", sigma, window); + parameter.Add("Time", typeof(DateTime), regTest, "test"); + + ValueSourceReporterHook valueHook = new ValueSourceReporterHook(TimeStep.Every(1, TimeScale.Epoch), "optimiser.cost_total"); + trainer.AddGlobalHook(valueHook); + sigma.SynchronisationHandler.AddSynchronisationSource(valueHook); + + var costBlock = (UserControlParameterVisualiser) parameter.Content.Add("Cost", typeof(double), trainer.Operator.Registry, "optimiser.cost_total"); + costBlock.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Epoch)); + + var learningBlock = (UserControlParameterVisualiser) parameter.Content.Add("Learning rate", typeof(double), trainer.Operator.Registry, "optimiser.learning_rate"); + learningBlock.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Epoch)); + + //trainer.AddGlobalHook(new RunningTimeReporter(TimeStep.Every(1, TimeScale.Epoch))); + + //var heeBlock = new SigmaTimeBlock(); + //heeBlock.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Epoch)); + //parameter.Content.Add(new Label { Content = "Cost" }, heeBlock, null, "optimiser.cost_total"); 
+ + window.TabControl["Overview"].AddCumulativePanel(cost1, 1, 2, legend: iris); + window.TabControl["Overview"].AddCumulativePanel(parameter); + window.TabControl["Overview"].AddCumulativePanel(accuracy1, 1, 2, legend: iris); + + //window.TabControl["Metrics"].AddCumulativePanel(cost2, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(weightAverage, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(biasesAverage, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(updateAverage, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(accuracy2, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(weightStddev, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(biasesStddev, legend: iris); + window.TabControl["Metrics"].AddCumulativePanel(updateStddev, legend: iris); + + if (SampleMnist) + { + // TODO validation panel + } + + // finish initialisation + window.IsInitializing = false; + }); + + // the operators should not run instantly but when the user clicks play + sigma.StartOperatorsOnRun = false; + + sigma.Prepare(); + + sigma.Run(); + } + + private static ITrainer CreateIrisTrainer(SigmaEnvironment sigma) + { + var irisReader = new CsvRecordReader(new MultiSource(new FileSource("iris.data"), new UrlSource("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"))); + IRecordExtractor irisExtractor = irisReader.Extractor("inputs", new[] { 0, 3 }, "targets", 4).AddValueMapping(4, "Iris-setosa", "Iris-versicolor", "Iris-virginica") + .Preprocess(new OneHotPreprocessor("targets", minValue: 0, maxValue: 2)) + .Preprocess(new AdaptiveNormalisingPreprocessor(minOutputValue: 0.0, maxOutputValue: 1.0)) + .Preprocess(new ShufflePreprocessor()); + + IDataset dataset = new ExtractedDataset("iris", ExtractedDataset.BlockSizeAuto, false, irisExtractor); + + ITrainer trainer = sigma.CreateTrainer("test"); + + trainer.Network = new Network(); + trainer.Network.Architecture = 
InputLayer.Construct(4) + + FullyConnectedLayer.Construct(4) + + FullyConnectedLayer.Construct(24) + + FullyConnectedLayer.Construct(3) + + OutputLayer.Construct(3) + + SoftMaxCrossEntropyCostLayer.Construct(); + + trainer.TrainingDataIterator = new MinibatchIterator(10, dataset); + trainer.AddNamedDataIterator("validation", new UndividedIterator(dataset)); + trainer.Optimiser = new AdadeltaOptimiser(decayRate: 0.9); + trainer.Operator = new CpuSinglethreadedOperator(); + + trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.3)); + trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.1)); + + return trainer; + } + + /// + /// Create a MNIST trainer (writing recognition) will be added to an environemnt. + /// + /// The sigma environemnt this trainer will be assigned to. + /// The newly created trainer. + private static ITrainer CreateMnistTrainer(SigmaEnvironment sigma) + { + ByteRecordReader mnistImageReader = new ByteRecordReader(headerLengthBytes: 16, recordSizeBytes: 28 * 28, source: new CompressedSource(new MultiSource(new FileSource("train-images-idx3-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")))); + IRecordExtractor mnistImageExtractor = mnistImageReader.Extractor("inputs", new[] { 0L, 0L }, new[] { 28L, 28L }).Preprocess(new NormalisingPreprocessor(0, 255)); + + ByteRecordReader mnistTargetReader = new ByteRecordReader(headerLengthBytes: 8, recordSizeBytes: 1, source: new CompressedSource(new MultiSource(new FileSource("train-labels-idx1-ubyte.gz"), new UrlSource("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")))); + IRecordExtractor mnistTargetExtractor = mnistTargetReader.Extractor("targets", new[] { 0L }, new[] { 1L }).Preprocess(new OneHotPreprocessor(minValue: 0, maxValue: 9)); + + IDataset dataset = new ExtractedDataset("mnist", ExtractedDataset.BlockSizeAuto, false, mnistImageExtractor, mnistTargetExtractor); + ITrainer trainer = 
sigma.CreateTrainer("test"); + + trainer.Network = new Network + { + Architecture = InputLayer.Construct(28, 28) + + FullyConnectedLayer.Construct(28 * 28) + + FullyConnectedLayer.Construct(10) + + OutputLayer.Construct(10) + + SoftMaxCrossEntropyCostLayer.Construct() + }; + + trainer.TrainingDataIterator = new MinibatchIterator(100, dataset); + trainer.AddNamedDataIterator("validation", new UndividedIterator(dataset)); + trainer.Optimiser = new AdadeltaOptimiser(decayRate: 0.9); + trainer.Operator = new CpuSinglethreadedOperator(); + + trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.1f)); + trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.1f, mean: 0.03f)); + + return trainer; + } + } } diff --git a/Sigma.Tests/Data/Datasets/TestDatasetBlockwiseSlice.cs b/Sigma.Tests/Data/Datasets/TestDatasetBlockwiseSlice.cs index 5f2cbec4..922e748b 100644 --- a/Sigma.Tests/Data/Datasets/TestDatasetBlockwiseSlice.cs +++ b/Sigma.Tests/Data/Datasets/TestDatasetBlockwiseSlice.cs @@ -48,7 +48,7 @@ public void TestDatasetBlockwiseSliceCreate() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename)).Extractor("inputs", 1, 2, "targets", 3); - Dataset dataset = new Dataset("name", Dataset.BlockSizeAuto, extractor); + ExtractedDataset dataset = new ExtractedDataset("name", ExtractedDataset.BlockSizeAuto, extractor); Assert.Throws(() => new DatasetBlockwiseSlice(null, 0, 0, 1)); Assert.Throws(() => new DatasetBlockwiseSlice(dataset, 0, 0, 0)); @@ -82,7 +82,7 @@ public void TestDatasetBlockwiseSliceFetch() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename)).Extractor("inputs", 0, "targets", 3); - Dataset dataset = new Dataset("name", 1, extractor); + ExtractedDataset dataset = new ExtractedDataset("name", 1, extractor); DatasetBlockwiseSlice slice = new DatasetBlockwiseSlice(dataset, 1, 2, 3); Assert.AreEqual(new float[] { 
4.9f }, slice.FetchBlock(0, new CpuFloat32Handler())["inputs"].GetDataAs().GetValuesArrayAs(0, 1)); diff --git a/Sigma.Tests/Data/Datasets/TestDatasetRecordwiseSlice.cs b/Sigma.Tests/Data/Datasets/TestDatasetRecordwiseSlice.cs index 929f3a32..dce8c9fe 100644 --- a/Sigma.Tests/Data/Datasets/TestDatasetRecordwiseSlice.cs +++ b/Sigma.Tests/Data/Datasets/TestDatasetRecordwiseSlice.cs @@ -48,7 +48,7 @@ public void TestDatasetRecordwiseSliceCreate() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename)).Extractor("inputs", 1, 2, "targets", 3); - Dataset dataset = new Dataset("name", Dataset.BlockSizeAuto, extractor); + ExtractedDataset dataset = new ExtractedDataset("name", ExtractedDataset.BlockSizeAuto, extractor); Assert.Throws(() => new DatasetRecordwiseSlice(null, 0.0, 1.0)); Assert.Throws(() => new DatasetRecordwiseSlice(dataset, -0.2, 1.0)); @@ -74,7 +74,7 @@ public void TestDatsetRecordwiseSliceFetch() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename)).Extractor("inputs", 0, "targets", 3); - Dataset dataset = new Dataset("name", 3, extractor); + ExtractedDataset dataset = new ExtractedDataset("name", 3, extractor); DatasetRecordwiseSlice slice = new DatasetRecordwiseSlice(dataset, 0.1, 0.6); Assert.AreEqual(new float[] { 5.1f, 4.9f }, slice.FetchBlock(0, new CpuFloat32Handler())["inputs"].GetDataAs().GetValuesArrayAs(0, 2)); diff --git a/Sigma.Tests/Data/Datasets/TestDataset.cs b/Sigma.Tests/Data/Datasets/TestExtractedDataset.cs similarity index 80% rename from Sigma.Tests/Data/Datasets/TestDataset.cs rename to Sigma.Tests/Data/Datasets/TestExtractedDataset.cs index dfe5e065..c63e3d44 100644 --- a/Sigma.Tests/Data/Datasets/TestDataset.cs +++ b/Sigma.Tests/Data/Datasets/TestExtractedDataset.cs @@ -21,7 +21,7 @@ For full license see LICENSE in the root directory of this project. 
namespace Sigma.Tests.Data.Datasets { - public class TestDataset : BaseLocaleTest + public class TestExtractedDataset : BaseLocaleTest { private static void RedirectGlobalsToTempPath() { @@ -53,20 +53,20 @@ public void TestDatasetCreate() CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename)).Extractor("inputs", 1, 2, "targets", 3); CsvRecordExtractor clashingExtractor = new CsvRecordReader(new FileSource(filename)).Extractor("inputs", 1, 2); - Assert.Throws(() => new Dataset(null, null)); - Assert.Throws(() => new Dataset("name", null)); - Assert.Throws(() => new Dataset("name", 10, null)); - Assert.Throws(() => new Dataset("name", 10, null, extractor)); + Assert.Throws(() => new ExtractedDataset(null, null)); + Assert.Throws(() => new ExtractedDataset("name", null)); + Assert.Throws(() => new ExtractedDataset("name", 10, null)); + Assert.Throws(() => new ExtractedDataset("name", 10, null, extractor)); - Assert.Throws(() => new Dataset("name", 10)); - Assert.Throws(() => new Dataset("name", -3, extractor)); - Assert.Throws(() => new Dataset("name")); - Assert.Throws(() => new Dataset("name", extractor, clashingExtractor)); + Assert.Throws(() => new ExtractedDataset("name", 10)); + Assert.Throws(() => new ExtractedDataset("name", -3, extractor)); + Assert.Throws(() => new ExtractedDataset("name")); + Assert.Throws(() => new ExtractedDataset("name", extractor, clashingExtractor)); - Assert.AreEqual("name", new Dataset("name", extractor).Name); + Assert.AreEqual("name", new ExtractedDataset("name", extractor).Name); - Assert.Greater(new Dataset("name", extractor).TargetBlockSizeRecords, 0); - Assert.Greater(new Dataset("name", Dataset.BlockSizeAuto, extractor).TargetBlockSizeRecords, 0); + Assert.Greater(new ExtractedDataset("name", extractor).TargetBlockSizeRecords, 0); + Assert.Greater(new ExtractedDataset("name", ExtractedDataset.BlockSizeAuto, extractor).TargetBlockSizeRecords, 0); DeleteTempFile(filename); } @@ -81,7 +81,7 @@ public void 
TestDatasetFetchBlockSequential() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename, Path.GetTempPath())).Extractor("inputs", 1, 2, "targets", 3); - Dataset dataset = new Dataset(name: "name", blockSizeRecords: 1, recordExtractors: extractor); + ExtractedDataset dataset = new ExtractedDataset(name: "name", blockSizeRecords: 1, recordExtractors: extractor); CpuFloat32Handler handler = new CpuFloat32Handler(); IDictionary namedArrays = dataset.FetchBlock(0, handler, false); @@ -122,7 +122,7 @@ public async Task TestDatasetFetchAsync() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename, Path.GetTempPath())).Extractor("inputs", 1, 2, "targets", 3); - Dataset dataset = new Dataset(name: "name", blockSizeRecords: 1, recordExtractors: extractor); + ExtractedDataset dataset = new ExtractedDataset(name: "name", blockSizeRecords: 1, recordExtractors: extractor); CpuFloat32Handler handler = new CpuFloat32Handler(); var block0 = dataset.FetchBlockAsync(0, handler); @@ -160,7 +160,7 @@ public void TestDatasetFreeBlockSequential() CreateCsvTempFile(filename); CsvRecordExtractor extractor = new CsvRecordReader(new FileSource(filename, Path.GetTempPath())).Extractor("inputs", 1, 2, "targets", 3); - Dataset dataset = new Dataset(name: "name", blockSizeRecords: 1, recordExtractors: extractor); + ExtractedDataset dataset = new ExtractedDataset(name: "name", blockSizeRecords: 1, recordExtractors: extractor); CpuFloat32Handler handler = new CpuFloat32Handler(); dataset.FetchBlock(0, handler, false); diff --git a/Sigma.Tests/Data/Iterators/TestMinibatchIterator.cs b/Sigma.Tests/Data/Iterators/TestMinibatchIterator.cs index 57a39c2c..383bf1a5 100644 --- a/Sigma.Tests/Data/Iterators/TestMinibatchIterator.cs +++ b/Sigma.Tests/Data/Iterators/TestMinibatchIterator.cs @@ -45,7 +45,7 @@ public void TestMinibatchIteratorCreate() FileSource source = new FileSource(filename, 
Path.GetTempPath()); CsvRecordExtractor extractor = (CsvRecordExtractor) new CsvRecordReader(source).Extractor(new CsvRecordExtractor(new Dictionary { ["inputs"] = new[] { new[] { 0 } } })); - Dataset dataset = new Dataset("test", 1, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestMinibatchIteratorYield)), true, extractor); + ExtractedDataset dataset = new ExtractedDataset("test", 1, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestMinibatchIteratorYield)), true, extractor); Assert.Throws(() => new MinibatchIterator(-3, dataset)); Assert.Throws(() => new MinibatchIterator(1, null)); @@ -76,7 +76,7 @@ public void TestMinibatchIteratorYield(int minibatchSize) FileSource source = new FileSource(filename, Path.GetTempPath()); CsvRecordExtractor extractor = (CsvRecordExtractor) new CsvRecordReader(source).Extractor(new CsvRecordExtractor(new Dictionary { ["inputs"] = new[] { new[] { 0 } } })); - Dataset dataset = new Dataset("test", 1, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestMinibatchIteratorYield)), true, extractor); + ExtractedDataset dataset = new ExtractedDataset("test", 1, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestMinibatchIteratorYield)), true, extractor); MinibatchIterator iterator = new MinibatchIterator(minibatchSize, dataset); IComputationHandler handler = new CpuFloat32Handler(); SigmaEnvironment sigma = SigmaEnvironment.Create("test"); diff --git a/Sigma.Tests/Data/Iterators/TestUndividedIterator.cs b/Sigma.Tests/Data/Iterators/TestUndividedIterator.cs index cdc1fb62..199c1a35 100644 --- a/Sigma.Tests/Data/Iterators/TestUndividedIterator.cs +++ b/Sigma.Tests/Data/Iterators/TestUndividedIterator.cs @@ -44,7 +44,7 @@ public void TestUndividedIteratorCreate() FileSource source = new FileSource(filename, Path.GetTempPath()); CsvRecordExtractor extractor = (CsvRecordExtractor) new CsvRecordReader(source).Extractor(new CsvRecordExtractor(new Dictionary { ["inputs"] = new[] { new[] { 0 } } })); - Dataset 
dataset = new Dataset("test", 1, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestUndividedIteratorCreate)), true, extractor); + ExtractedDataset dataset = new ExtractedDataset("test", 1, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestUndividedIteratorCreate)), true, extractor); Assert.Throws(() => new UndividedIterator(null)); @@ -64,7 +64,7 @@ public void TestUndividedIteratorYield() FileSource source = new FileSource(filename, Path.GetTempPath()); CsvRecordExtractor extractor = (CsvRecordExtractor) new CsvRecordReader(source).Extractor(new CsvRecordExtractor(new Dictionary { ["inputs"] = new[] { new[] { 0 } } })); - Dataset dataset = new Dataset("test", 2, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestUndividedIteratorCreate)), true, extractor); + ExtractedDataset dataset = new ExtractedDataset("test", 2, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestUndividedIteratorCreate)), true, extractor); UndividedIterator iterator = new UndividedIterator(dataset); SigmaEnvironment sigma = SigmaEnvironment.Create("test"); IComputationHandler handler = new CpuFloat32Handler(); diff --git a/Sigma.Tests/Data/Iterators/TestUnifiedIterator.cs b/Sigma.Tests/Data/Iterators/TestUnifiedIterator.cs index bd77f73f..5cbb5934 100644 --- a/Sigma.Tests/Data/Iterators/TestUnifiedIterator.cs +++ b/Sigma.Tests/Data/Iterators/TestUnifiedIterator.cs @@ -52,7 +52,7 @@ public void TestUnifiedIteratorYield() FileSource source = new FileSource(filename, Path.GetTempPath()); CsvRecordExtractor extractor = (CsvRecordExtractor) new CsvRecordReader(source).Extractor(new CsvRecordExtractor(new Dictionary { ["inputs"] = new[] { new[] { 0 } } })); - Dataset dataset = new Dataset("test", 2, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestUnifiedIteratorYield)), true, extractor); + ExtractedDataset dataset = new ExtractedDataset("test", 2, new DiskCacheProvider(Path.GetTempPath() + "/" + nameof(TestUnifiedIteratorYield)), true, extractor); 
UnifiedIterator iterator = new UnifiedIterator(dataset); SigmaEnvironment sigma = SigmaEnvironment.Create("test"); IComputationHandler handler = new CpuFloat32Handler(); diff --git a/Sigma.Tests/Sigma.Tests.csproj b/Sigma.Tests/Sigma.Tests.csproj index 936769f7..9e5b2e94 100644 --- a/Sigma.Tests/Sigma.Tests.csproj +++ b/Sigma.Tests/Sigma.Tests.csproj @@ -123,7 +123,7 @@ - + diff --git a/Sigma.Tests/Training/Initialisers/TestConstantValueInitialiser.cs b/Sigma.Tests/Training/Initialisers/TestConstantValueInitialiser.cs index 9edbc610..20acf8ea 100644 --- a/Sigma.Tests/Training/Initialisers/TestConstantValueInitialiser.cs +++ b/Sigma.Tests/Training/Initialisers/TestConstantValueInitialiser.cs @@ -13,6 +13,7 @@ For full license see LICENSE in the root directory of this project. using System; using Sigma.Core.Handlers.Backends.SigmaDiff.NativeCpu; using Sigma.Core.MathAbstract.Backends.SigmaDiff; +using Sigma.Core.MathAbstract.Backends.SigmaDiff.NativeCpu; namespace Sigma.Tests.Training.Initialisers { @@ -23,9 +24,10 @@ public void TestConstantValueInitialiserInitialise() { ConstantValueInitialiser initialiser = new ConstantValueInitialiser(2.0); - INDArray array = new ADNDArray(2, 1, 2, 2); IComputationHandler handler = new CpuFloat32Handler(); - Random random = new Random(); + INDArray array = handler.NDArray(2L, 1L, 2L, 2L); + + Random random = new Random(); Assert.Throws(() => initialiser.Initialise((INDArray) null, handler, random)); Assert.Throws(() => initialiser.Initialise((INumber) null, handler, random)); diff --git a/Sigma.Tests/Training/MockTrainer.cs b/Sigma.Tests/Training/MockTrainer.cs index a9986d48..5adea2c0 100644 --- a/Sigma.Tests/Training/MockTrainer.cs +++ b/Sigma.Tests/Training/MockTrainer.cs @@ -40,7 +40,7 @@ protected MockTrainer(string name) : base(name) extractor.SectionNames = new[] {"targets", "inputs"}; extractor.Reader = new MockRecordReader(); Sigma = SigmaEnvironment.GetOrCreate("testificate-mocktrainer"); - TrainingDataIterator = new 
UndividedIterator(new Dataset("testificate", extractor)); + TrainingDataIterator = new UndividedIterator(new ExtractedDataset("testificate", extractor)); } internal class MockRecordReader : IRecordReader diff --git a/Sigma.sln.DotSettings b/Sigma.sln.DotSettings index 8b5efca9..58b7cff8 100644 --- a/Sigma.sln.DotSettings +++ b/Sigma.sln.DotSettings @@ -58,4 +58,5 @@ True True True - True \ No newline at end of file + True + True \ No newline at end of file