diff --git a/README.md b/README.md index 51be3318..64da8cf1 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,12 @@ The naming scheme should be obvious and must be followed. Compile with make. Converting the data from ntuples =========== -``convertFromSource.py -i -o -c TrainData_NanoML`` +``convertFromSource.py -i -o -c TrainData_NanoML`` The conversion rule itself is located here: ``modules/datastructures/TrainData_NanoML.py`` +Other conversion rules / data structures can also be defined. For maximum compatibility, it is advised to follow the NanoML example w.r.t. the final outputs. + The training files (see next section) usually also contain a comment in the beginning pointing to the latest data set at CERN and flatiron. Standard training and inference @@ -68,7 +70,7 @@ cd Train Look at the first lines of the file `std_training.py` containing a short description and where to find the dataset compatible with that training file. Then execute the following command to run a training. ``` -python3 std_training.py /training_data.djcdc +python3 baseline_training.py /training_data.djcdc ``` Please notice that the standard configuration might or might not include writing the printout to a file in the training output directory. diff --git a/Train/baseline_training.py b/Train/baseline_training.py new file mode 100644 index 00000000..5933c51b --- /dev/null +++ b/Train/baseline_training.py @@ -0,0 +1,338 @@ +''' + +Compatible with the dataset here: +/eos/home-j/jkiesele/ML4Reco/Gun20Part_NewMerge/train + +On flatiron: +/mnt/ceph/users/jkieseler/HGCalML_data/Gun20Part_NewMerge/train + +not compatible with datasets before end of Jan 2022 + +''' + +import tensorflow as tf + +from tensorflow.keras.layers import Dense, Concatenate + +from DeepJetCore.DJCLayers import StopGradient + +from Layers import RaggedGlobalExchange, DistanceWeightedMessagePassing, DictModel +from Layers import RaggedGravNet, ScaledGooeyBatchNorm2 +from Regularizers import AverageDistanceRegularizer +from LossLayers import LLFullObjectCondensation +from DebugLayers import PlotCoordinates + +from model_blocks import condition_input, extent_coords_if_needed, create_outputs, re_integrate_to_full_hits + +from callbacks import plotClusterSummary + +from DeepJetCore.training.DeepJet_callbacks import simpleMetricsCallback + + +#loss options: +loss_options={ + # here and in the following energy = momentum + 'energy_loss_weight': 0., + 'q_min': 1., + # addition to original OC, adds average position for clusterin + # usually 0.5 is a reasonable value to break degeneracies + # and keep training smooth enough + 'use_average_cc_pos': 0.5, + 'classification_loss_weight':0.0, + 'position_loss_weight':0., + 'timing_loss_weight':0., + 'beta_loss_scale':1., + # these weights will downweight low energies, for a + # training sample with a good energy distribution, + # this won't be needed. + 'use_energy_weights': False, + # this is the standard repulsive hinge loss from the paper + 'implementation': 'hinge' + } + + +# elu behaves well, likely fine +dense_activation='elu' + +# record internal metrics every N batches +record_frequency=10 +# plot every M times, metrics were recorded. In other words, +# plotting will happen every M*N batches +plotfrequency=50 + +learningrate = 1e-4 + +# this is the maximum number of hits (points) per batch, +# not the number of events (samples). This is safer w.r.t. 
+# memory +nbatch = 10000 + +#iterations of gravnet blocks +n_neighbours=[64,64] + +# 3 is a bit low but nice in the beginning since it can be plotted +n_cluster_space_coordinates = 3 +n_gravnet_dims = 3 + + +def gravnet_model(Inputs, + td, + debug_outdir=None, + plot_debug_every=record_frequency*plotfrequency, + ): + #################################################################################### + ##################### Input processing, no need to change much here ################ + #################################################################################### + + input_list = td.interpretAllModelInputs(Inputs,returndict=True) + input_list = condition_input(input_list, no_scaling=True) + + #just for info what's available, prints once + print('available inputs',[k for k in input_list.keys()]) + + rs = input_list['row_splits'] + t_idx = input_list['t_idx'] + energy = input_list['rechit_energy'] + c_coords = input_list['coords'] + + ## build inputs + + x_in = Concatenate()([input_list['coords'], + input_list['features']]) + + x_in = ScaledGooeyBatchNorm2( + fluidity_decay=0.1 #freeze out quickly, just to get good input preprocessing + )(x_in) + + x = x_in + + c_coords = ScaledGooeyBatchNorm2( + fluidity_decay=0.1 #same here + )(c_coords) + + + #################################################################################### + ##################### now the actual model goes below ############################## + #################################################################################### + + # output of each iteration will be concatenated + allfeat = [] + + # extend coordinates already here if needed, just as a good starting point + c_coords = extent_coords_if_needed(c_coords, x, n_gravnet_dims) + + for i in range(len(n_neighbours)): + + # derive new coordinates for clustering + x = RaggedGlobalExchange()([x, rs]) + + x = Dense(64,activation=dense_activation)(x) + x = Dense(64,activation=dense_activation)(x) + x = Dense(64,activation=dense_activation)(x) + x = Concatenate()([c_coords,x]) #give a good starting point + x = ScaledGooeyBatchNorm2()(x) + + xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=n_neighbours[i], + n_dimensions=n_gravnet_dims, + n_propagate=64, #this is the number of features that are exchanged + n_filters=64, #output dense + feature_activation = 'elu', + )([x, rs]) + + x = Concatenate()([x,xgn]) + + # mostly to record average distances etc. 
can be used to force coordinates + # to be within reasonable range (but usually not needed) + gndist = AverageDistanceRegularizer(strength=1e-6, + record_metrics=True + )(gndist) + + #for information / debugging, can also be safely removed + gncoords = PlotCoordinates(plot_every = plot_debug_every, outdir = debug_outdir, + name='gn_coords_'+str(i))([gncoords, + energy, + t_idx, + rs]) + # we have to pass them downwards, otherwise the layer above gets optimised away + # but we don't want the gradient to be disturbed, so it gets stopped + gncoords = StopGradient()(gncoords) + x = Concatenate()([gncoords,x]) + + # this repeats the distance weighted message passing step from gravnet + # on the same graph topology + x = DistanceWeightedMessagePassing([64,64], + activation=dense_activation + )([x,gnnidx,gndist]) + + x = ScaledGooeyBatchNorm2()(x) + + x = Dense(64,activation=dense_activation)(x) + x = Dense(64,activation=dense_activation)(x) + x = Dense(64,activation=dense_activation)(x) + + x = ScaledGooeyBatchNorm2()(x) + + allfeat.append(x) + + + + x = Concatenate()([c_coords]+allfeat)#gives a prior to the clustering coords + #create one global feature vector + xg = Dense(512,activation=dense_activation,name='glob_dense_'+str(i))(x) + x = RaggedGlobalExchange()([xg, rs]) + x = Concatenate()([x,xg]) + # last part of network + x = Dense(64,activation=dense_activation)(x) + x = ScaledGooeyBatchNorm2()(x) + x = Dense(64,activation=dense_activation)(x) + x = ScaledGooeyBatchNorm2()(x) + x = Dense(64,activation=dense_activation)(x) + x = ScaledGooeyBatchNorm2()(x) + + + ####################################################################### + ########### the part below should remain almost unchanged ############# + ########### of course with the exception of the OC loss ############# + ########### weights ############# + ####################################################################### + + #use a standard batch norm at the last stage + + + pred_beta, pred_ccoords, pred_dist,\ + pred_energy_corr, pred_energy_low_quantile, pred_energy_high_quantile,\ + pred_pos, pred_time, pred_time_unc, pred_id = create_outputs(x, n_ccoords=n_cluster_space_coordinates) + + # loss + pred_beta = LLFullObjectCondensation(scale=1., + record_metrics=True, + print_loss=True, + name="FullOCLoss", + **loss_options + )( # oc output and payload + [pred_beta, pred_ccoords, pred_dist, + pred_energy_corr,pred_energy_low_quantile,pred_energy_high_quantile, + pred_pos, pred_time, pred_time_unc, + pred_id] + + [energy]+ + # truth information + [input_list['t_idx'] , + input_list['t_energy'] , + input_list['t_pos'] , + input_list['t_time'] , + input_list['t_pid'] , + input_list['t_spectator_weight'], + input_list['t_fully_contained'], + input_list['t_rec_energy'], + input_list['t_is_unique'], + input_list['row_splits']]) + + # fast feedback + pred_ccoords = PlotCoordinates(plot_every=plot_debug_every, outdir = debug_outdir, + name='condensation_coords')([pred_ccoords, pred_beta,input_list['t_idx'], + rs]) + + # just to have a defined output, only adds names + model_outputs = re_integrate_to_full_hits( + input_list, + pred_ccoords, + pred_beta, + pred_energy_corr, + pred_energy_low_quantile, + pred_energy_high_quantile, + pred_pos, + pred_time, + pred_id, + pred_dist + ) + + return DictModel(inputs=Inputs, outputs=model_outputs) + + + +import training_base_hgcal +train = training_base_hgcal.HGCalTraining() + +if not train.modelSet(): + train.setModel(gravnet_model, + td=train.train_data.dataclass(), + 
debug_outdir=train.outputDir+'/intplots') + + train.setCustomOptimizer(tf.keras.optimizers.Nadam(clipnorm=1.,epsilon=1e-2)) + # + train.compileModel(learningrate=1e-4) + + train.keras_model.summary() + + +verbosity = 2 +import os + +publishpath = None #this can be an ssh reachable path (be careful: needs tokens / keypairs) + +# establish callbacks + + +cb = [ + simpleMetricsCallback( + output_file=train.outputDir+'/metrics.html', + record_frequency= record_frequency, + plot_frequency = plotfrequency, + select_metrics='FullOCLoss_*loss', + publish=publishpath #no additional directory here (scp cannot create one) + ), + + simpleMetricsCallback( + output_file=train.outputDir+'/latent_space_metrics.html', + record_frequency= record_frequency, + plot_frequency = plotfrequency, + select_metrics='average_distance_*', + publish=publishpath + ), + + + simpleMetricsCallback( + output_file=train.outputDir+'/val_metrics.html', + call_on_epoch=True, + select_metrics='val_*', + publish=publishpath #no additional directory here (scp cannot create one) + ), + + + + + ] + + +cb += [ + plotClusterSummary( + outputfile=train.outputDir + "/clustering/", + samplefile=train.val_data.getSamplePath(train.val_data.samples[0]), + after_n_batches=200 + ) + ] + +#cb=[] + +train.change_learning_rate(learningrate) + +model, history = train.trainModel(nepochs=3, + batchsize=nbatch, + additional_callbacks=cb) + +print("freeze BN") +# Note the submodel here its not just train.keras_model +#for l in train.keras_model.layers: +# if 'FullOCLoss' in l.name: +# l.q_min/=2. + +train.change_learning_rate(learningrate/2.) + + +model, history = train.trainModel(nepochs=121, + batchsize=nbatch, + additional_callbacks=cb) + + + + diff --git a/Train/cheplike_training.py b/Train/cheplike_training.py index 53ccd52c..9ec2f252 100644 --- a/Train/cheplike_training.py +++ b/Train/cheplike_training.py @@ -53,6 +53,13 @@ from GravNetLayersRagged import CastRowSplits +### from graph / pointcloud pooling + +from GraphCondensationLayers import point_pool, point_scatter + + +### + import globals if False: #for testing @@ -68,14 +75,14 @@ #loss options: loss_options={ - 'energy_loss_weight': .25, + 'energy_loss_weight': .0, 'q_min': 1.5, 'use_average_cc_pos': 0.1, 'classification_loss_weight':0.0, - 'too_much_beta_scale': 1e-5 , + 'too_much_beta_scale': 0. 
, 'position_loss_weight':1e-5, 'timing_loss_weight':0.1, - 'beta_loss_scale':2., + 'beta_loss_scale':1., 'beta_push': 0#0.01 #push betas gently up at low values to not lose the gradients } @@ -87,7 +94,7 @@ plotfrequency=50 #plots every 1k batches learningrate = 1e-6 -nbatch = 100000 +nbatch = 20000 if globals.acc_ops_use_tf_gradients: #for tf gradients the memory is limited nbatch = 60000 @@ -110,13 +117,8 @@ def gravnet_model(Inputs, is_preselected = isinstance(td, TrainData_PreselectionNanoML) pre_selection = td.interpretAllModelInputs(Inputs,returndict=True) - - #can be loaded - or use pre-selected dataset (to be made) - if not is_preselected: - pre_selection = pre_selection_model(pre_selection,trainable=False,pass_through=False) - else: - pre_selection['row_splits'] = CastRowSplits()(pre_selection['row_splits']) - print(">> preselected dataset will omit pre-selection step") + + pre_selection = pre_selection_model(pre_selection,trainable=False,pass_through=True) #just for info what's available print('available pre-selection outputs',[k for k in pre_selection.keys()]) @@ -133,7 +135,7 @@ def gravnet_model(Inputs, c_coords = pre_selection['coords']#pre-clustered coordinates t_idx = pre_selection['t_idx'] - #################################################################################### + ################################################################################# ##################### now the actual model goes below ############################## #################################################################################### @@ -194,6 +196,42 @@ def gravnet_model(Inputs, x = ScaledGooeyBatchNorm2()(x) + allgt = [] + prs = rs + od = { + 'x': x, + 't_spectator_weight': pre_selection['t_spectator_weight'], + 't_idx': pre_selection['t_idx'], + 'is_track': pre_selection['is_track'] + } + t, od, prs = point_pool(od, prs, name="p_pool_a_"+str(i)) + od['x'],_,_,_ = RaggedGravNet(n_neighbours=n_neighbours[i], + name='gn_pooled_a_'+str(i), + n_dimensions=n_dims, + n_filters=64, + n_propagate=64, + record_metrics=True, + coord_initialiser_noise=1e-2, + use_approximate_knn=False #weird issue with that for now + )([od['x'], prs]) + allgt.append(t) + + t, od, prs = point_pool(od, prs, name="p_pool_b_"+str(i)) + od['x'],_,_,_ = RaggedGravNet(n_neighbours=n_neighbours[i], + name='gn_pooled_b_'+str(i), + n_dimensions=n_dims, + n_filters=64, + n_propagate=64, + record_metrics=True, + coord_initialiser_noise=1e-2, + use_approximate_knn=False #weird issue with that for now + )([od['x'], prs]) + allgt.append(t) + + + xp = point_scatter(od['x'], allgt, name = 'p_scatter_'+str(i)) + x = Concatenate()([x,xp]) + x = ScaledGooeyBatchNorm2()(x) allfeat.append(x) @@ -417,9 +455,7 @@ def gravnet_model(Inputs, l.q_min/=2. train.change_learning_rate(learningrate/2.) 
-nbatch = 160000 -if globals.acc_ops_use_tf_gradients: #for tf gradients the memory is limited - nbatch = 60000 + model, history = train.trainModel(nepochs=121, batchsize=nbatch, diff --git a/Train/pre_condensation_training.py b/Train/pre_condensation_training.py deleted file mode 100644 index be2bc4de..00000000 --- a/Train/pre_condensation_training.py +++ /dev/null @@ -1,180 +0,0 @@ -''' - - -Compatible with the dataset here: -/eos/home-j/jkiesele/ML4Reco/Gun20Part_NewMerge/train - -On flatiron: -/mnt/ceph/users/jkieseler/HGCalML_data/Gun20Part_NewMerge/train - -not compatible with datasets before end of Jan 2022 - -''' - -import tensorflow as tf -# from K import Layer - -from datastructures import TrainData_NanoML - -#from tensorflow.keras import Model -from Layers import DictModel - -from model_blocks import pre_condensation_model, mini_pre_condensation_model - -K=12 #12 - -plot_frequency= 20 # 150 #150 # 1000 #every 20 minutes approx -record_frequency = 3 - -def pretrain_model(Inputs, - td, - debugplots_after=record_frequency*plot_frequency, #10 minutes: ~600 - debug_outdir=None, - publishpath=None): - - orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True) - - presel = mini_pre_condensation_model(orig_inputs, - record_metrics=True, - trainable=True, - t_d=0.5, # just starting point - t_b=0.6, # just starting point - q_min=1., - purity_target=0.96, - condensation_mode = 'std', # std, precond, pushpull, simpleknn - noise_threshold=0.15, - print_batch_time=False, - condensate=True, - cluster_dims = 3, - cleaning_threshold=0.5, - debug_outdir=debug_outdir, - debugplots_after=debugplots_after, - publishpath=publishpath - ) - presel.pop('noise_backscatter') - return DictModel(inputs=Inputs, outputs=presel) - -import training_base_hgcal -train = training_base_hgcal.HGCalTraining() - -publishpath = "jkiesele@lxplus.cern.ch:~/Cernbox/www/files/temp/Sept2022/" -publishpath += [d for d in train.outputDir.split('/') if len(d)][-1] - -print('will attempt to publish to',publishpath) - -if not train.modelSet(): - train.setModel(pretrain_model, - td = train.train_data.dataclass(), - debug_outdir=train.outputDir+'/intplots', - publishpath=publishpath) - - train.saveCheckPoint("before_training.h5") - train.setCustomOptimizer(tf.keras.optimizers.Adam()) - # - train.compileModel(learningrate=1e-4) - - train.keras_model.summary() - - #start somewhere - #from model_tools import apply_weights_from_path - #import os - #path_to_pretrained = os.getenv("HGCALML")+'/models/pre_selection_jan/KERAS_model.h5' - #train.keras_model = apply_weights_from_path(path_to_pretrained,train.keras_model) - - - - -verbosity = 2 -import os - -samplepath=train.val_data.getSamplePath(train.val_data.samples[0]) -# publishpath = 'jkiesele@lxplus.cern.ch:/eos/home-j/jkiesele/www/files/HGCalML_trainings/'+os.path.basename(os.path.normpath(train.outputDir)) - -from DeepJetCore.training.DeepJet_callbacks import simpleMetricsCallback - - - -cb = [ - - simpleMetricsCallback( - output_file=train.outputDir+'/reduction_metrics.html', - record_frequency= record_frequency, - plot_frequency = plot_frequency, - select_metrics=['*_reduction', '*_purity','*_cleaned_fraction','*contamination'],#includes time - publish=publishpath #no additional directory here (scp cannot create one) - ), - - - #simpleMetricsCallback( - # output_file=train.outputDir+'/hit_reduction_metrics.html', - # record_frequency= record_frequency, - # plot_frequency = plot_frequency, - # select_metrics=['*reduction*hits*','*_reduction*lost*'],#includes time - # 
publish=publishpath #no additional directory here (scp cannot create one) - # ), - # - simpleMetricsCallback( - output_file=train.outputDir+'/noise_metrics.html', - record_frequency= record_frequency, - plot_frequency = plot_frequency, - select_metrics=['*noise*accuracy','*noise*loss','*noise*reduction','*purity','*efficiency'], - publish=publishpath #no additional directory here (scp cannot create one) - ), - - - simpleMetricsCallback( - output_file=train.outputDir+'/time.html', - record_frequency= 2.*record_frequency,#doesn't change anyway - plot_frequency = plot_frequency, - select_metrics='*time*', - publish=publishpath #no additional directory here (scp cannot create one) - ), - - simpleMetricsCallback( - output_file=train.outputDir+'/losses.html', - record_frequency= record_frequency, - plot_frequency = plot_frequency, - select_metrics=['*_loss','*simple_knn_oc*'], - publish=publishpath #no additional directory here (scp cannot create one) - ), - - #simpleMetricsCallback( - # output_file=train.outputDir+'/gooey.html', - # record_frequency= record_frequency, - # plot_frequency = plot_frequency, - # select_metrics='*gooey*', - # publish=publishpath #no additional directory here (scp cannot create one) - # ), - - - simpleMetricsCallback( - output_file=train.outputDir+'/oc_thresh.html', - record_frequency= record_frequency, - plot_frequency = plot_frequency, - select_metrics='*_ll_*oc_thresholds*', - publish=publishpath #no additional directory here (scp cannot create one) - ), - - simpleMetricsCallback( - output_file=train.outputDir+'/val_metrics.html', - call_on_epoch=True, - select_metrics='val_*', - publish=publishpath #no additional directory here (scp cannot create one) - ), - - ] - -#cb=[] -nbatch = 150000 -train.change_learning_rate(5e-4) -train.trainModel(nepochs=1, batchsize=nbatch,additional_callbacks=cb) - -nbatch = 150000 -train.change_learning_rate(3e-5) -train.trainModel(nepochs=10,batchsize=nbatch,additional_callbacks=cb) - -print('reducing learning rate to 1e-4') -train.change_learning_rate(1e-5) -nbatch = 200000 - -train.trainModel(nepochs=100,batchsize=nbatch,additional_callbacks=cb) diff --git a/Train/pre_tiny_pc_pool_training.py b/Train/pre_tiny_pc_pool_training.py new file mode 100644 index 00000000..ab219e75 --- /dev/null +++ b/Train/pre_tiny_pc_pool_training.py @@ -0,0 +1,252 @@ +''' + + +Compatible with the dataset here: +/eos/home-j/jkiesele/ML4Reco/Gun20Part_NewMerge/train + +On flatiron: +/mnt/ceph/users/jkieseler/HGCalML_data/Gun20Part_NewMerge/train + +not compatible with datasets before end of Jan 2022 + +''' + +import globals +if False: #for testing + globals.acc_ops_use_tf_gradients = True + globals.knn_ops_use_tf_gradients = True + +import tensorflow as tf +# from K import Layer + +#from datastructures import TrainData_NanoML + +#from tensorflow.keras import Model +from Layers import DictModel, PlotCoordinates, Where +from tensorflow.keras.layers import Concatenate, Dense + +from model_blocks import tiny_pc_pool, condition_input +from GraphCondensationLayers import add_attention, PushUp +from callbacks import NanSweeper + + +plot_frequency= 40 # 150 #150 # 1000 #every 20 minutes approx +record_frequency = 20 + +reduction_target = 0.05 +lr_factor = 1. +nbatch = 170000 + +no_publish = False + +train_second = True +if train_second: + lr_factor = reduction_target/5. 
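+    # illustrative arithmetic only: with reduction_target = 0.05 as set above this gives
+    # lr_factor = 0.01, so the change_learning_rate(lr_factor*...) calls at the bottom of this
+    # script effectively run the schedule 2e-5 -> 1e-5 -> 1e-6 -> 1e-7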
+ nbatch = 170000 + +train_all = False + +def pretrain_model(Inputs, + td, + debugplots_after=record_frequency*plot_frequency, #10 minutes: ~600 + debug_outdir=None, + publishpath=None): + + orig_inputs = td.interpretAllModelInputs(Inputs,returndict=True) + presel = condition_input(orig_inputs, no_scaling=True) + + + + presel['prime_coords'] = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name='pc_pool_coords_pre', + publish=publishpath)( + [presel['prime_coords'], + presel['rechit_energy'], + presel['t_idx'],presel['row_splits']]) + + trans,presel = tiny_pc_pool(presel, + reduction_target = reduction_target, + trainable=not train_second or train_all, + record_metrics = True, + publish=publishpath, + debugplots_after=debugplots_after, + debug_outdir=debug_outdir, + ) + + + presel['cond_coords'] = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name='pc_pool_cond_coords0', + publish=publishpath)( + [presel['cond_coords'], + presel['rechit_energy'],#Where(0.5)([presel['is_track'],presel['rechit_energy']]), + presel['t_idx'],presel['row_splits']]) + + + presel['prime_coords'] = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name='pc_pool_post_prime0', + publish=publishpath)( + [presel['prime_coords'], + presel['rechit_energy'],#Where(0.5)([presel['is_track'],presel['rechit_energy']]), + presel['t_idx'],presel['row_splits']]) + + presel['select_prime_coords'] = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name='pc_pool_post_sel_prime', + publish=publishpath)( + [presel['select_prime_coords'], + presel['rechit_energy'],#Where(0.5)([presel['is_track'],presel['rechit_energy']]), + presel['t_idx'],presel['row_splits']]) + + + + if train_second: + + trans,presel = tiny_pc_pool(presel, + #coords = coords, + name='pre_graph_pool1', + is_second = True, + reduction_target=0.1, + trainable=True, + #coords = coords, + #low_energy_cut_target = 1.0, + record_metrics = True, + publish=publishpath, + debugplots_after=debugplots_after, + debug_outdir=debug_outdir + ) + + + presel['prime_coords'] = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name='pc_pool_post_prime1', + publish=publishpath)( + [presel['prime_coords'], + presel['rechit_energy'],#Where(1.)([presel['is_track'],presel['rechit_energy']]), + presel['t_idx'],presel['row_splits']]) + + + presel['cond_coords'] = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name='pc_pool_cond_coords1', + publish=publishpath)( + [presel['cond_coords'], + presel['rechit_energy'],#Where(0.5)([presel['is_track'],presel['rechit_energy']]), + presel['t_idx'],presel['row_splits']]) + + presel.update(trans) #put them all in + #presel.pop('row_splits') + return DictModel(inputs=Inputs, outputs=presel) + +import training_base_hgcal +train = training_base_hgcal.HGCalTraining() + +publishpath = "jkiesele@lxplus.cern.ch:~/Cernbox/www/files/temp/June2023/" +publishpath += [d for d in train.outputDir.split('/') if len(d)][-1] + +if no_publish: + publishpath = None + +print('will attempt to publish to',publishpath) + +if not train.modelSet(): + train.setModel(pretrain_model, + td = train.train_data.dataclass(), + debug_outdir=train.outputDir+'/intplots', + publishpath=publishpath) + + train.saveCheckPoint("before_training.h5") + train.setCustomOptimizer(tf.keras.optimizers.Adam(clipnorm=1.)) + # + train.compileModel(learningrate=1e-4) + + train.keras_model.summary() + + #start somewhere + #from model_tools import apply_weights_from_path + #import os + 
#path_to_pretrained = os.getenv("HGCALML")+'/models/pre_selection_jan/KERAS_model.h5' + #train.keras_model = apply_weights_from_path(path_to_pretrained,train.keras_model) + + + + +verbosity = 2 +import os + +samplepath=train.val_data.getSamplePath(train.val_data.samples[0]) +#publishpath = "jkiesele@lxplus.cern.ch:~/Cernbox/www/files/temp/June2023/"+os.path.basename(os.path.normpath(train.outputDir)) + +from DeepJetCore.training.DeepJet_callbacks import simpleMetricsCallback + + + +cb = [ + + NanSweeper(), + + simpleMetricsCallback( + output_file=train.outputDir+'/reduction_metrics.html', + record_frequency = record_frequency , + plot_frequency = plot_frequency, + select_metrics=['*_reduction'],#includes time + publish=publishpath #no additional directory here (scp cannot create one) + ), + + + simpleMetricsCallback( + output_file=train.outputDir+'/hit_reduction_metrics.html', + record_frequency = record_frequency , + plot_frequency = plot_frequency, + select_metrics=['*hits*','*lost*','*tracks*'],#includes time + publish=publishpath #no additional directory here (scp cannot create one) + ), + + + simpleMetricsCallback( + output_file=train.outputDir+'/losses.html', + record_frequency = record_frequency , + plot_frequency = plot_frequency, + select_metrics='*_loss', + publish=publishpath #no additional directory here (scp cannot create one) + ), + + + + simpleMetricsCallback( + output_file=train.outputDir+'/val_metrics.html', + call_on_epoch=True, + select_metrics='val_*', + publish=publishpath #no additional directory here (scp cannot create one) + ), + + + simpleMetricsCallback( + output_file=train.outputDir+'/batchnorm.html', + call_on_epoch=True, + select_metrics='*norm*', + publish=publishpath #no additional directory here (scp cannot create one) + ), + + ] + +#cb=[] + +train.change_learning_rate(lr_factor*2e-3) +train.trainModel(nepochs=1, batchsize=nbatch,additional_callbacks=cb) + +train.change_learning_rate(lr_factor*1e-3) +train.trainModel(nepochs=10, batchsize=nbatch,additional_callbacks=cb) + +train.change_learning_rate(lr_factor*1e-4) +train.trainModel(nepochs=60, batchsize=nbatch,additional_callbacks=cb) + +train.change_learning_rate(lr_factor*1e-5) +train.trainModel(nepochs=80, batchsize=nbatch,additional_callbacks=cb) + +exit() #done +#nbatch = 150000 +train.change_learning_rate(3e-4) +train.trainModel(nepochs=10,batchsize=nbatch,additional_callbacks=cb) + +print('reducing learning rate to 1e-4') +train.change_learning_rate(1e-5) +#nbatch = 200000 + +train.trainModel(nepochs=100,batchsize=nbatch,additional_callbacks=cb) diff --git a/Train/pz_test.py b/Train/pz_test.py new file mode 100755 index 00000000..ac0fac5c --- /dev/null +++ b/Train/pz_test.py @@ -0,0 +1,436 @@ +''' +Intended to be used on toy data set found on FI +/eos/home-p/phzehetn/ML4Reco/Data/V4/Train_cut11/dataCollection.djcdc + +As of November 10th, 2022 both classification loss and timing loss do not +work and should be left at 0.0 in the LOSS_OPTIONS +''' + + +import globals +if True: #for testing + #globals.acc_ops_use_tf_gradients = True + globals.knn_ops_use_tf_gradients = True + +import tensorflow as tf +from tensorflow.keras.layers import Dense, Concatenate, Dropout, Dropout + +import training_base_hgcal +from DeepJetCore.training.DeepJet_callbacks import simpleMetricsCallback +from DeepJetCore.DJCLayers import StopGradient +from datastructures import TrainData_PreselectionNanoML + +from Layers import RaggedGravNet, RaggedGlobalExchange +from Layers import DistanceWeightedMessagePassing +from Layers 
import DictModel, SphereActivation, Multi +from Layers import CastRowSplits, PlotCoordinates +from Layers import LLFullObjectCondensation as LLExtendedObjectCondensation +from Layers import ScaledGooeyBatchNorm2, Sqrt +from Layers import LLFillSpace, SphereActivation +from Regularizers import AverageDistanceRegularizer +from model_blocks import create_outputs +from model_blocks import extent_coords_if_needed +from model_blocks import tiny_pc_pool, condition_input +from model_tools import apply_weights_from_path + +from callbacks import NanSweeper, DebugPlotRunner +from Layers import layernorm +import os + + +############################################################################### +### Configure model and training here ######################################### +############################################################################### + +LOSS_OPTIONS = { + 'energy_loss_weight': 1e-6, + 'q_min': 0.5, + 'use_average_cc_pos': 0.5, + 'classification_loss_weight':0., # to make it work0.5, + 'too_much_beta_scale': 0.0, + 'position_loss_weight':0., + 'timing_loss_weight':0.0, + 'beta_loss_scale':1., #2.0 + 'implementation': 'hinge' #'hinge_manhatten'#'hinge'#old school + } + +BATCHNORM_OPTIONS = { + 'max_viscosity': 0.5 #keep very batchnorm like + } + +# Configuration for model +PRESELECTION_PATH = os.getenv("HGCALML")+'/models/tiny_pc_pool/model_no_nan.h5'#model.h5' + +# Configuration for plotting +RECORD_FREQUENCY = 20 +PLOT_FREQUENCY = 10 #plots every 200 batches -> roughly 3 minutes +PUBLISHPATH = "jkiesele@lxplus.cern.ch:~/Cernbox/www/files/temp/June2023/" +#PUBLISHPATH = None + +# Configuration for training +DENSE_ACTIVATION='elu' #layernorm #'elu' +LEARNINGRATE = 5e-3 +LEARNINGRATE2 = 1e-3 +LEARNINGRATE3 = 1e-4 +NBATCH = 200000#200000 +DENSE_REGULARIZER = tf.keras.regularizers.L2(l2=1e-5) +DENSE_REGULARIZER = None + +# Configuration of GravNet Blocks +N_NEIGHBOURS = [256, 256, 256] +TOTAL_ITERATIONS = len(N_NEIGHBOURS) +N_CLUSTER_SPACE_COORDINATES = 3 +N_GRAVNET = 3 + +############################################################################### +### Define model ############################################################## +############################################################################### + +def gravnet_model(Inputs, td, debug_outdir=None, + plot_debug_every=RECORD_FREQUENCY*PLOT_FREQUENCY, + publish = None): + ############################################################################ + ##################### Input processing, no need to change much here ######## + ############################################################################ + + pre_selection = td.interpretAllModelInputs(Inputs, returndict=True) + + pre_selection = condition_input(pre_selection, no_scaling=True) + trans, pre_selection = tiny_pc_pool( + pre_selection, + record_metrics=True, + #trainable=True + )#train in one go.. what is up with the weight loading? 
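+    # the pre-pool weights are applied further down via
+    # apply_weights_from_path(PRESELECTION_PATH, train.keras_model) when the input data is not
+    # already pre-selected; if the pooling block should instead stay frozen, tiny_pc_pool also
+    # accepts trainable=False (as in pre_tiny_pc_pool_training.py) - illustrative, not done here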
+ + #just for info what's available + print('available pre-selection outputs',list(pre_selection.keys())) + + rs = pre_selection['row_splits'] + is_track = pre_selection['is_track'] + + x_in = Concatenate()([pre_selection['prime_coords'], + pre_selection['features']]) + + #x_in = Concatenate()([x_in, is_track, SphereActivation()(x_in)]) + x_in = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x_in) + x_in = Dense(128, activation=DENSE_ACTIVATION)(x_in) + x_in = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x_in) + x = x_in + energy = pre_selection['rechit_energy'] + c_coords = pre_selection['prime_coords']#pre-clustered coordinates + c_coords = ScaledGooeyBatchNorm2( + fluidity_decay=0.5, #can freeze almost immediately + )(c_coords) + t_idx = pre_selection['t_idx'] + + c_coords = PlotCoordinates( + plot_every=plot_debug_every, + outdir=debug_outdir, + name='input_c_coords', + publish = publish + )([c_coords, energy, t_idx, rs]) + + ############################################################################ + ##################### now the actual model goes below ###################### + ############################################################################ + + allfeat = [] + + #extend coordinates already here if needed, starting point for gravnet + + c_coords = extent_coords_if_needed(c_coords, x, N_GRAVNET) + + ## not needed, embedding already done in the pre-pooling + #x_track = Dense(64, + # activation=DENSE_ACTIVATION, + # kernel_regularizer=DENSE_REGULARIZER)(x) + #x_hit = Dense(64, + # activation=DENSE_ACTIVATION, + # kernel_regularizer=DENSE_REGULARIZER)(x) + #is_track_bool = tf.cast(is_track, tf.bool) + #x = tf.where(is_track_bool, x_track, x_hit) + + for i in range(TOTAL_ITERATIONS): + + #x,n = SphereActivation(return_norm=True)(x) + + x = Dense(64,activation=DENSE_ACTIVATION, + kernel_regularizer=DENSE_REGULARIZER)(x) + #x = Dropout(0.1)(x) + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + x = Dense(64,activation=DENSE_ACTIVATION, + kernel_regularizer=DENSE_REGULARIZER)(x) + + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + #x = Dropout(0.1)(x) + + x = Concatenate()([c_coords,x]) + + xgn, gncoords, gnnidx, gndist = RaggedGravNet( + n_neighbours=N_NEIGHBOURS[i], + n_dimensions=N_GRAVNET, + n_filters=64, + n_propagate=64, + coord_initialiser_noise=1e-5, + #sumwnorm = True, + )([x, rs]) + + + xgn = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(xgn) + + gndist = AverageDistanceRegularizer( + strength=1e-2, + record_metrics=True + )(gndist) + + gncoords = PlotCoordinates( + plot_every=plot_debug_every, + outdir=debug_outdir, + name='gn_coords_'+str(i), + publish = publish + )([gncoords, energy, t_idx, rs]) + gncoords = StopGradient()(gncoords) + + x = Concatenate()([gncoords,xgn,x]) + + #does the same but with batch norm + for nn in [64,64,64,64]: + + #d_mult = ScalarMultiply(2.)(Dense(1,activation='sigmoid')(x)) + #gndist = Multi()([gndist,d_mult])#scale distances here dynamically + x = SphereActivation()(x) + x = DistanceWeightedMessagePassing( + [nn], + activation=DENSE_ACTIVATION, + #sumwnorm = True, + )([x, gnnidx, gndist]) + x = Dense(128,activation=DENSE_ACTIVATION, + kernel_regularizer=DENSE_REGULARIZER)(x) + #x = Dropout(0.1)(x) + # + + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + x = Dense(64,name='dense_past_mp_'+str(i),activation=DENSE_ACTIVATION, + kernel_regularizer=DENSE_REGULARIZER)(x) + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + #x = Dropout(0.25)(x) + x = Dense(64,activation=DENSE_ACTIVATION, + kernel_regularizer=DENSE_REGULARIZER)(x) + + #x = 
Multi()([x,n]) #back to full space + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + + allfeat.append(x) + + x = Concatenate()(allfeat) + x = RaggedGlobalExchange()([x,rs]) + #x = Concatenate()([x,SphereActivation()(x)]) + x = Dense(64, name='Last_Dense_1', activation=DENSE_ACTIVATION)(x) + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + x = Dense(64, name='Last_Dense_2', activation=DENSE_ACTIVATION)(x) + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + #x = Dropout(0.1)(x) + x = Dense(64, name='Last_Dense_3', activation=DENSE_ACTIVATION)(x)#we want this to be not bounded + #x = Dropout(0.1)(x) + ########################################################################### + ########### the part below should remain almost unchanged ################# + ########### of course with the exception of the OC loss ################# + ########### weights ################# + ########################################################################### + + x = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x) + # x = Concatenate()([x]) + + pred_beta, pred_ccoords, pred_dist, \ + pred_energy_corr, pred_energy_low_quantile, pred_energy_high_quantile, \ + pred_pos, pred_time, pred_time_unc, pred_id = \ + create_outputs(x, n_ccoords=N_CLUSTER_SPACE_COORDINATES, fix_distance_scale=True) + + pred_ccoords = LLFillSpace(maxhits=1000, runevery=1, + scale=0.01, + record_metrics=True, + print_loss=True, + print_batch_time=True)([pred_ccoords, rs, + pre_selection['t_idx']]) + + + # loss + pred_beta = LLExtendedObjectCondensation( + scale=1., + use_energy_weights=False,#well distributed anyways + record_metrics=True, + print_loss=True, + name="ExtendedOCLoss", + **LOSS_OPTIONS + )( # oc output and payload + [pred_beta, + pred_ccoords, + pred_dist, + pred_energy_corr, + pred_energy_low_quantile, + pred_energy_high_quantile, + pred_pos, + pred_time, + pred_time_unc, + pred_id] + + [energy] + + # truth information + [pre_selection['t_idx'] , + pre_selection['t_energy'] , + pre_selection['t_pos'] , + pre_selection['t_time'] , + pre_selection['t_pid'] , + pre_selection['t_spectator_weight'], + pre_selection['t_fully_contained'], + pre_selection['t_rec_energy'], + pre_selection['t_is_unique'], + pre_selection['row_splits'] ] + ) + + #fast feedback + pred_ccoords = PlotCoordinates( + plot_every=plot_debug_every, + outdir = debug_outdir, + name='condensation', + publish = publish + )([pred_ccoords, pred_beta,pre_selection['t_idx'], rs]) + model_outputs = { + 'pred_beta': pred_beta, + 'pred_ccoords': pred_ccoords, + 'pred_energy_corr_factor': pred_energy_corr, + 'pred_energy_low_quantile': pred_energy_low_quantile, + 'pred_energy_high_quantile': pred_energy_high_quantile, + 'pred_pos': pred_pos, + 'pred_time': pred_time, + 'pred_id': pred_id, + 'pred_dist': pred_dist, + 'rechit_energy': energy, + 'row_splits': pre_selection['row_splits'], #are these the selected ones or not? + 'no_noise_sel': trans['sel_idx_up'], + 'no_noise_rs': trans['rs_down'], #unclear what that actually means? 
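+        # from CreateGraphCondensation: 'sel_idx_up' are the indices of the points kept by the
+        # pooling step and 'rs_down' are the row splits of the full, un-pooled input
+        # (see GraphCondensationLayers.py)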
+ 'sel_idx': trans['sel_idx_up'], #just a duplication but more intuitive to understand + 'sel_t_idx': pre_selection['t_idx'] #for convenience + # 'noise_backscatter': pre_selection['noise_backscatter'], + } + + return DictModel(inputs=Inputs, outputs=model_outputs) + +############################################################################### +### Model defined, set up training ############################################ +############################################################################### + +train = training_base_hgcal.HGCalTraining() + +if PUBLISHPATH is not None: + PUBLISHPATH += [d for d in train.outputDir.split('/') if len(d)][-1] + +if not train.modelSet(): + train.setModel( + gravnet_model, + td=train.train_data.dataclass(), + debug_outdir=train.outputDir+'/intplots', + publish = PUBLISHPATH + ) + train.setCustomOptimizer(tf.keras.optimizers.Nadam(clipnorm=2., + epsilon=1e-2)) + train.compileModel(learningrate=LEARNINGRATE) + train.keras_model.summary() + + if not isinstance(train.train_data.dataclass(), TrainData_PreselectionNanoML): + train.keras_model = apply_weights_from_path(PRESELECTION_PATH, train.keras_model) + + #exit() + +############################################################################### +### Create Callbacks ########################################################## +############################################################################### + +val_samplepath = train.val_data.getSamplePath(train.val_data.samples[0]) +cb = [] + + +cb += [ + + NanSweeper(),#this takes a bit of time checking each batch but could be worth it + + simpleMetricsCallback( + output_file=train.outputDir+'/metrics.html', + record_frequency= RECORD_FREQUENCY, + plot_frequency = PLOT_FREQUENCY, + select_metrics=['ExtendedOCLoss*','FullOCLoss_*loss','*ll_fill_space*'], + publish=PUBLISHPATH #no additional directory here (scp cannot create one) + ), + + + simpleMetricsCallback( + output_file=train.outputDir+'/gndist.html', + record_frequency= RECORD_FREQUENCY, + plot_frequency = PLOT_FREQUENCY, + select_metrics=['*average_distance*'], + publish=PUBLISHPATH #no additional directory here (scp cannot create one) + ), + + # collect all pre pooling metrics here + simpleMetricsCallback( + output_file=train.outputDir+'/pgp_metrics.html', + record_frequency= RECORD_FREQUENCY, + plot_frequency = PLOT_FREQUENCY, + select_metrics='*pre_graph_pool*', + publish=PUBLISHPATH + ), + + simpleMetricsCallback( + output_file=train.outputDir+'/val_metrics.html', + call_on_epoch=True, + select_metrics='val_*', + publish=PUBLISHPATH #no additional directory here (scp cannot create one) + ), + + #triggers debug plots within the model on a specific sample + DebugPlotRunner( + plot_frequency = 2, #testing + sample = val_samplepath + ) + ] + + +############################################################################### +### Start training ############################################################ +############################################################################### + +print("Batch size: ", NBATCH) +train.change_learning_rate(LEARNINGRATE) +model, history = train.trainModel( + nepochs=2, + batchsize=NBATCH, + additional_callbacks=cb + ) + +train.change_learning_rate(LEARNINGRATE2) +model, history = train.trainModel( + nepochs=4, + batchsize=NBATCH, + additional_callbacks=cb + ) + + +train.change_learning_rate(LEARNINGRATE3) +model, history = train.trainModel( + nepochs=6, + batchsize=NBATCH, + additional_callbacks=cb + ) + +for l in train.keras_model.layers: + if isinstance(l, 
ScaledGooeyBatchNorm2): + l.trainable = False + +train.compileModel(learningrate=LEARNINGRATE3) + +model, history = train.trainModel( + nepochs=8, + batchsize=NBATCH, + additional_callbacks=cb + ) + +exit() diff --git a/models/tiny_pc_pool/model.h5 b/models/tiny_pc_pool/model.h5 new file mode 100644 index 00000000..534cb2e4 Binary files /dev/null and b/models/tiny_pc_pool/model.h5 differ diff --git a/models/tiny_pc_pool/modules.tar.gz b/models/tiny_pc_pool/modules.tar.gz new file mode 100644 index 00000000..93336582 Binary files /dev/null and b/models/tiny_pc_pool/modules.tar.gz differ diff --git a/modules/ActivationLayers.py b/modules/ActivationLayers.py new file mode 100755 index 00000000..2ceffa8c --- /dev/null +++ b/modules/ActivationLayers.py @@ -0,0 +1,28 @@ + +import tensorflow as tf + +class GroupSort(tf.keras.layers.Layer): + + def compute_output_shape(self, input_shapes): + return input_shapes + + def call(self, input): + + return tf.sort(input, axis=-1) + + +class Sphere(tf.keras.layers.Layer): + + def compute_output_shape(self, input_shapes): + out = [] + for s in input_shapes: + out.append(s) + out[-1] += 1 + return out + + def call(self, x): + norm = tf.reduce_sum(x**2, axis=-1,keepdims=True) + norm = tf.sqrt(norm+1e-6) + x = tf.concat([x / norm, norm], axis=-1) + return input + \ No newline at end of file diff --git a/modules/DebugLayers.py b/modules/DebugLayers.py index 24c0309b..7b2b791e 100644 --- a/modules/DebugLayers.py +++ b/modules/DebugLayers.py @@ -16,12 +16,13 @@ import os class CumulativeArray(object): - def __init__(self, capacity = 60, default=0.): + def __init__(self, capacity = 60, default=0., name=None): assert capacity > 0 self.data = None self.capacity = capacity self.default = default + self.name = name def put(self, arr): arr = np.where(arr == np.nan, self.default, arr) @@ -109,6 +110,7 @@ def __init__(self, outdir :str='' , plot_only_training=True, publish = None, + externally_triggered = False, **kwargs): if 'dynamic' in kwargs: @@ -117,6 +119,8 @@ def __init__(self, super(_DebugPlotBase, self).__init__(dynamic=False,**kwargs) self.plot_every = plot_every + self.externally_triggered = externally_triggered + self.triggered = False self.plot_only_training = plot_only_training if len(outdir) < 1: self.plot_every=0 @@ -131,7 +135,8 @@ def __init__(self, def get_config(self): config = {'plot_every': self.plot_every, 'outdir': self.outdir, - 'publish': self.publish} + 'publish': self.publish, + 'externally_triggered': self.externally_triggered} base_config = super(_DebugPlotBase, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -149,6 +154,10 @@ def create_base_output_path(self): return self.outdir+'/'+self.name def check_make_plot(self, inputs, training = None): + + if self.externally_triggered: + return self.triggered + out=inputs if isinstance(inputs,list): out=inputs[0] @@ -179,6 +188,7 @@ def call(self, inputs, training=None): out=inputs if isinstance(inputs,list): out=inputs[0] + self.add_loss(0. 
* tf.reduce_sum(out[0]))#to keep it alive if not self.check_make_plot(inputs, training): return out @@ -344,9 +354,13 @@ def plot(self, inputs, training=None): elif len(inputs) == 6: coords, features, hoverfeat, nidx, tidx, rs = inputs + #give each an index + idxs = np.arange(features.shape[0]) + #just select first coords = coords[0:rs[1]] tidx = tidx[0:rs[1]] + idxs = idxs[0:rs[1]] if len(tidx.shape) <2: tidx = tidx[...,tf.newaxis] features = features[0:rs[1]] @@ -367,7 +381,8 @@ def plot(self, inputs, training=None): 'Y': coords[:,1+i:2+i].numpy(), 'Z': coords[:,2+i:3+i].numpy(), 'tIdx': tidx[:,0:1].numpy(), - 'features': features[:,0:1].numpy() + 'features': features[:,0:1].numpy(), + 'idx' : idxs[...,np.newaxis] } hoverdict={} if hoverfeat is not None: @@ -383,7 +398,7 @@ def plot(self, inputs, training=None): rdst = np.random.RandomState(1234567890)#all the same shuffle_truth_colors(df,'tIdx',rdst) - hover_data=['orig_tIdx']+[k for k in hoverdict.keys()] + hover_data=['orig_tIdx','idx']+[k for k in hoverdict.keys()] if nidx is not None: hover_data.append('av_same') fig = px.scatter_3d(df, x="X", y="Y", z="Z", @@ -544,7 +559,10 @@ def plot(self, inputs, training=None): class PlotGraphCondensationEfficiency(_DebugPlotBase): - def __init__(self, update = 0.1, **kwargs): + def __init__(self, + accumulate_every :int = 10 , #how + externally_triggered = False, + **kwargs): ''' Inputs: - t_energy @@ -555,9 +573,25 @@ def __init__(self, update = 0.1, **kwargs): - t_energy ''' - super(PlotGraphCondensationEfficiency, self).__init__(**kwargs) - self.num = CumulativeArray(40) - self.den = CumulativeArray(40) + super(PlotGraphCondensationEfficiency, self).__init__(externally_triggered=externally_triggered, + **kwargs) + + self.acc_counter = 0 + self.accumulate_every = accumulate_every + + self.only_accumulate_this_time = False + + accumulate = self.plot_every // accumulate_every + 50 + + self.num = CumulativeArray(accumulate, name = self.name+'_num') + self.den = CumulativeArray(accumulate, name = self.name+'_den') + + + + def get_config(self): + config = {'accumulate_every': self.accumulate_every}#outdir/publish is explicitly not saved and needs to be set again every time + base_config = super(PlotGraphCondensationEfficiency, self).get_config() + return dict(list(base_config.items()) + list(config.items())) #overwrite here def call(self, t_energy, t_idx, graph_trans , training=None): @@ -567,7 +601,6 @@ def call(self, t_energy, t_idx, graph_trans , training=None): os.system('mkdir -p '+self.outdir) try: - print(self.name, 'plotting...') self.plot(t_energy, t_idx, graph_trans,training) except Exception as e: raise e @@ -575,6 +608,25 @@ def call(self, t_energy, t_idx, graph_trans , training=None): return t_energy + def check_make_plot(self, inputs, training = None): + pre = super(PlotGraphCondensationEfficiency, self).check_make_plot(inputs, training) + + if self.plot_every <= 0 and not self.externally_triggered: #nothing + return pre + + self.only_accumulate_this_time = False + #OR: + if self.accumulate_every < self.acc_counter: + self.acc_counter = 0 + self.only_accumulate_this_time = not pre + + return True + + self.acc_counter += 1 + return pre + + + def plot(self, t_energy, t_idx, graph_trans, training=None): ''' @@ -623,10 +675,14 @@ def plot(self, t_energy, t_idx, graph_trans, training=None): h_orig, _ = np.histogram(orig_energies, bins = bins) h_orig = np.array(h_orig, dtype='float32') - self.den.put(h) + self.den.put(h_orig) - ##interface to old code + if self.only_accumulate_this_time: 
+ return + + print(self.name, 'plotting...') + ##interface to old code h = self.num.get() h_orig = self.den.get() @@ -646,7 +702,7 @@ def plot(self, t_energy, t_idx, graph_trans, training=None): fig.write_html(self.outdir+'/'+self.name+'.html') if self.publish is not None: - publish(self.outdir+'/'+self.name+'.html', self.publish) + publish(self.outdir+'/'+self.name+'.html', self.publish) diff --git a/modules/GraphCondensationLayers.py b/modules/GraphCondensationLayers.py index 35365f1d..7483ee97 100755 --- a/modules/GraphCondensationLayers.py +++ b/modules/GraphCondensationLayers.py @@ -59,31 +59,80 @@ def __init__(self,*args,**kwargs): super().__init__(*args,**kwargs) ## just for convenience ## + def check(self): + assert self['weights_down'].shape[1] == self['nidx_down'].shape[1] +from GravNetLayersRagged import SortAndSelectNeighbours + class CreateGraphCondensation(tf.keras.layers.Layer): def __init__(self, K=5, score_threshold=0.5, + reduction_target=None, n_knn_bins = 21, safeguard = True, #makes sure there are never no points selected per row split + print_reduction = False, **kwargs): super(CreateGraphCondensation, self).__init__(**kwargs) self.K = K self.score_threshold = score_threshold + if reduction_target is not None: + assert reduction_target > 0 and reduction_target < 1. + self.reduction_target = reduction_target self.n_knn_bins = n_knn_bins self.safeguard = safeguard + self.print_reduction = print_reduction def get_config(self): - config = {'K': self.K, 'score_threshold': self.score_threshold, 'n_knn_bins': self.n_knn_bins, 'safeguard': self.safeguard} + config = {'K': self.K, 'score_threshold': self.score_threshold, + 'reduction_target': self.reduction_target, + 'n_knn_bins': self.n_knn_bins, 'safeguard': self.safeguard, 'print_reduction': self.print_reduction} base_config = super(CreateGraphCondensation, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + + def build(self,input_shape): - def call(self, score, coords, rs, always_promote=None, training = None): + def _init(shape, dtype=None): + return tf.constant(self.score_threshold)[...,tf.newaxis] + + self.dyn_score_threshold = self.add_weight(name = 'dyn_th', shape=(1, ), + initializer = _init, + constraint = 'non_neg', + trainable = False) + + super(CreateGraphCondensation, self).build(input_shape) + + def update_thresh(self, trans, training): + if self.reduction_target is None: + return + if not self.trainable: #establish expected behaviour + return + smoothness = 10. #hard coded, but should be fine + + red = tf.cast(trans['rs_up'][-1], + dtype='float32') / tf.cast(trans['rs_down'][-1], + dtype='float32') + red_diff = red - self.reduction_target # < 1, > -1 + step_up = (1. 
- self.dyn_score_threshold) * red_diff + step_down = self.dyn_score_threshold * red_diff + #if reduction is larger than target (diff < 0), score needs to step up + step = tf.where(red_diff > 0., step_up, step_down) + score_update = self.dyn_score_threshold + step / smoothness #slight reduction for safety + #update only in training phase + score_update = tf.keras.backend.in_train_phase(score_update, + self.dyn_score_threshold, + training=training) + tf.keras.backend.update(self.dyn_score_threshold,score_update) + + tf.print(self.name, 'dyn th',self.dyn_score_threshold, 'red', red, 'target', self.reduction_target) + + def call(self, score, coords, rs, nidx = None, dist = None, always_promote=None, training = None): trans = GraphCondensation() trans['rs_down'] = rs @@ -98,12 +147,13 @@ def call(self, score, coords, rs, always_promote=None, training = None): - any other number: can be neighbour and have neighbours ''' - direction = tf.where(score > self.score_threshold, 0, direction) + direction = tf.where(score > self.dyn_score_threshold, 0, direction) if always_promote is not None: - direction = tf.where(always_promote>0, 2, direction) + direction = tf.where(always_promote>0, 2, direction) #this should be a 2!! score = tf.where(always_promote>0, 1., score) + #make this indices for gather and scatter sel = tf.range(tf.shape(score)[0])[...,tf.newaxis] @@ -111,20 +161,20 @@ def call(self, score, coords, rs, always_promote=None, training = None): rsel = tf.RaggedTensor.from_row_splits(sel,rs) rscore = tf.RaggedTensor.from_row_splits(score,rs) - threshold = self.score_threshold + threshold = self.dyn_score_threshold #make sure there is something left, bad with very inhomogenous batches, but good for training if self.safeguard: mrss = tf.reduce_max(rscore,axis=1, keepdims=True) threshold = tf.reduce_min( tf.concat( [ tf.reduce_min(mrss)[tf.newaxis]*0.98, - tf.constant(self.score_threshold)[tf.newaxis] ], axis=0)) + self.dyn_score_threshold ], axis=0)) rsel = tf.ragged.boolean_mask(rsel, rscore[...,0] >= threshold) #use ragged to select trans['rs_up'] = tf.cast(rsel.row_splits,'int32')#for whatever reason - print(self.name, 'rs down',trans['rs_down']) - print(self.name, 'rs up',trans['rs_up']) + #print(self.name, 'rs down',trans['rs_down']) + #print(self.name, 'rs up',trans['rs_up']) #undo ragged trans['sel_idx_up'] = rsel.values @@ -132,8 +182,13 @@ def call(self, score, coords, rs, always_promote=None, training = None): tf.assert_greater(tf.shape(score)[0]+1, tf.shape(trans['sel_idx_up'])[0])]): - nidx, dist = select_knn(self.K+1, coords, rs, direction = direction, + if (nidx is not None) and (dist is not None): + dist, nidx = SortAndSelectNeighbours.raw_call(dist,nidx, K=self.K+1) + raise ValueError("not implemented yet. Needs cleaning w.r.t. directions.") + else: #yes this is swapped ordering + nidx, dist = select_knn(self.K+1, coords, rs, direction = direction, n_bins = self.n_knn_bins, name=self.name) + nidx = tf.reshape(nidx, [-1, self.K+1]) #to define shape for later dist = tf.reshape(dist, [-1, self.K+1]) dist = tf.where(nidx<0,0.,dist)#needed? 
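+        # most likely yes: nidx < 0 marks padded / missing neighbours (same convention as in
+        # PushUp and LLGraphCondensationEdges below), so their distances are zeroed here rather
+        # than left at whatever select_knn returned for them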
@@ -162,10 +217,14 @@ def call(self, score, coords, rs, always_promote=None, training = None): trans['distsq_down'] = tf.reshape(trans['distsq_down'], [-1, self.K]) trans['weights_down'] = tf.reshape(trans['weights_down'], [-1, self.K]) + if self.print_reduction: + print(self.name, 'reduction', trans['rs_up'][-1], 'from', trans['rs_down'][-1], + tf.cast(trans['rs_up'][-1],dtype='float32')/tf.cast(trans['rs_down'][-1],dtype='float32') * 100, '%') #curiosity: #print(self.name, 'max number of assigned:', tf.reduce_max( tf.unique_with_counts( tf.reshape(trans['nidx_down'], [-1]) )[2] )) #trans.check_filled() # just during debugging + self.update_thresh(trans,training) return trans @@ -249,7 +308,7 @@ def call(self,features, transition : GraphCondensation, weight = None): up_f = features if weight is not None: - weight = tf.nn.relu(weight) #safe guard + weight = tf.nn.relu(weight) # + 1e-4 #safe guard, assume weights are O(1) up_f *= weight if self.mode == 'mean': @@ -259,14 +318,18 @@ def call(self,features, transition : GraphCondensation, weight = None): nidx = transition['nidx_down'] nweights = transition['weights_down'] + if self.add_self: - nidx = tf.concat([tf.range(tf.shape(nidx)[0])[:,tf.newaxis], nidx],axis=1) - nweights = tf.concat([tf.ones_like(nweights[:,0:1]), nweights],axis=1) + snidx = tf.concat([tf.range(tf.shape(nidx)[0])[:,tf.newaxis], nidx[:,1:]*0 -1 ],axis=1) + is_up = nidx[:,0:1] < 0 + nidx = tf.where(is_up, snidx, nidx) up_f = push_sum(nweights, up_f, nidx) up_f = tf.gather_nd(up_f, transition['sel_idx_up']) - if self.mode == 'mean': - up_f = tf.math.divide_no_nan(up_f[:,1:] , up_f[:,0:1] + 1e-3 ) + if self.mode == 'mean': + wsum = tf.nn.relu(up_f[:,0:1]) #just to catch numerics + wsum = tf.where(wsum>0., wsum, 1e-3) + up_f = tf.math.divide_no_nan(up_f[:,1:] , wsum) up_f = tf.reshape(up_f, [-1, features.shape[1]])#just so the shapes are defined upon placeholder call return up_f @@ -349,6 +412,26 @@ def call(self, features, transition : GraphCondensation, weights=None): graph_condensation_layers['PullDown'] = PullDown + +class SelectDown(tf.keras.layers.Layer): + + def call(self, features, transition : GraphCondensation): + + #simply copied down + down_f = tf.scatter_nd(transition['sel_idx_up'], + features, + shape = [tf.shape(transition['weights_down'])[0], + tf.shape(features)[1]]) + + nidx = transition['nidx_down'] + out = tf.reshape(select(nidx, down_f, 0.), [tf.shape(nidx)[0], features.shape[1] * nidx.shape[1]]) + print(self.name,'out shape',out.shape) + return out + +graph_condensation_layers['SelectDown'] = SelectDown + + + class Mix(tf.keras.layers.Layer): ''' Simply mixes the connected 'up' features with the 'down' features. 
@@ -542,16 +625,19 @@ def __init__(self, with tf.name_scope(self.name + "/0/"): - self.pre_dense = tf.keras.layers.Dense(self.pre_nodes, activation='elu') + self.pre_dense = tf.keras.layers.Dense(self.pre_nodes, activation='elu', + trainable = self.trainable) self.edge_dense = [] self.self_dense = [] for i,n in enumerate(edge_dense): with tf.name_scope(self.name + "/1/" + str(i)): - self.edge_dense.append(tf.keras.layers.Dense(n, activation='elu')) + self.edge_dense.append(tf.keras.layers.Dense(n, activation='elu', + trainable = self.trainable)) with tf.name_scope(self.name + "/1/" + str(i+1)): - self.edge_dense.append(tf.keras.layers.Dense(self.K)) + self.edge_dense.append(tf.keras.layers.Dense(self.K, + trainable = self.trainable)) if self_dense is None: self_dense = edge_dense @@ -560,10 +646,12 @@ def __init__(self, for i,n in enumerate(self_dense): with tf.name_scope(self.name + "/2/" + str(i)): - self.self_dense.append(tf.keras.layers.Dense(n, activation='elu')) + self.self_dense.append(tf.keras.layers.Dense(n, activation='elu', + trainable = self.trainable)) with tf.name_scope(self.name + "/2/" + str(i+1)): - self.self_dense.append(tf.keras.layers.Dense(1)) + self.self_dense.append(tf.keras.layers.Dense(1, + trainable = self.trainable)) def get_config(self): config = {'edge_dense': [self.edge_dense[i].units for i in range(len(self.edge_dense)-1)], @@ -609,13 +697,14 @@ def call(self, x, transition : GraphCondensation): for d in self.edge_dense: x_n = d(x_n) - x_s = x - for d in self.self_dense: - x_s = d(x_s) + if self.no_self: x = x_n else: + x_s = x + for d in self.self_dense: + x_s = d(x_s) x = tf.concat([x_s, x_n],axis=1) x = tf.nn.softmax(x,axis=1) @@ -994,7 +1083,7 @@ def __init__(self, ''' super(LLGraphCondensationEdges, self).__init__(**kwargs) - self.cce = tf.keras.losses.CategoricalCrossentropy() + self.cce = tf.keras.losses.CategoricalCrossentropy(reduction = tf.keras.losses.Reduction.NONE) def call(self, x_e, trans, t_idx): @@ -1007,24 +1096,34 @@ def loss(self, inputs): assert len(inputs) == 4 e_score, dist, nidx, t_idx = inputs + dist = tf.stop_gradient(dist) #important + K = nidx.shape[1] def one_hot_default(): ones = tf.ones_like(t_idx, dtype='float32') # V x 1 zeros = tf.zeros_like(nidx, dtype='float32') # V x K return tf.concat([ones,zeros],axis=1) # V x K + 1 - # n_t_idx = select(nidx, t_idx, -2) n_same = t_idx[:,:,tf.newaxis] == n_t_idx #V x K x 1 - dist = tf.where(nidx < 0, 1e6, dist) + #all_same = tf.logi + no_n_mask = nidx < 0 + dist = tf.where(no_n_mask, 1e6, dist) dist = tf.where(n_same[:,:,0], dist, 1e6) closest_and_same = tf.argmin( dist, axis=1) #no noise treatment yet #needs noise one_hot = tf.one_hot(closest_and_same + 1, K+1) one_hot = tf.where(t_idx < 0, one_hot_default(), one_hot) + #now where all are the same, and distances similar, don't apply loss + all_same = tf.reduce_all(n_same, axis=1)#all the same + distance_similar = (tf.reduce_max(dist, axis=1) - tf.reduce_min(dist, axis=1)) < 0.1 # assumes all normed to 1 + lossval = self.cce(one_hot, e_score) - return lossval + #if it doesn't matter anyway don't apply hard loss + lossval = tf.where( tf.logical_and(all_same, distance_similar), 0., lossval ) #tbi + + return tf.reduce_mean(lossval) graph_condensation_layers['LLGraphCondensationEdges'] = LLGraphCondensationEdges @@ -1039,16 +1138,110 @@ def __init__(self, **kwargs): - GraphCondensation - t_idx - t_energy + - is_track (opt) ''' super(MLGraphCondensationMetrics, self).__init__(**kwargs) - def call(self, graph_transition : GraphCondensation, 
t_idx, t_energy): + def call(self, graph_transition : GraphCondensation, t_idx, t_energy, is_track = None): gt = graph_transition - self.metrics_call([gt['sel_idx_up'], t_idx, t_energy, gt['rs_down'], gt['rs_up']]) + if is_track is None: + self.metrics_call([gt['sel_idx_up'], t_idx, t_energy, gt['rs_down'], gt['rs_up']]) + else: + self.metrics_call([gt['sel_idx_up'], t_idx, t_energy, is_track, gt['rs_down'], gt['rs_up']]) return graph_transition graph_condensation_layers['MLGraphCondensationMetrics'] = MLGraphCondensationMetrics + + + +# convenience function +from LossLayers import LLClusterCoordinates + +def add_attention(graph_transition, x, name, trainable = True): + a = graph_transition.copy() + att = tf.keras.layers.Dense(a['weights_down'].shape[1], activation='softmax', name=name, trainable = trainable)(x) + a['weights_down'] = att + return a + + +def point_pool(indict : dict, rs, name="p_pool", n_heads = 3, K_loss=64, trainable=True): + + #dict needs to contain: x, is_track, t_idx, t_spectator_weight + + x, is_track, t_idx, t_spectator_weight = indict['x'], indict['is_track'], indict['t_idx'], indict['t_spectator_weight'] + + score = tf.keras.layers.Dense(1, activation='sigmoid',name=name+'_gc_score', trainable = trainable)(x) + coords = tf.keras.layers.Dense(3, name=name+'_xyz_cond', use_bias = False, trainable = trainable)(x) + + coords = LLClusterCoordinates( + active = trainable, + downsample=5000, #no need to use all + scale = 1., + name = name+'_ll_cc', + ignore_noise = True, #this is filtered by the graph condensation anyway + hinge_mode = True + )([coords, t_idx, t_spectator_weight, + score, rs ]) + + score = LLGraphCondensationScore( + active = trainable, + name = name+'_ll_score', + K=K_loss, + )([score, coords, t_idx, rs]) + + trans_a = CreateGraphCondensation( + print_reduction = False, + name = name+'_gc_create', + score_threshold = 0.5, + K=5 + )(score,coords,rs, + always_promote = is_track) + + out = [] + for i in range(n_heads): + att = add_attention(trans_a, x, name+'_up_att_'+str(i), trainable = trainable) + out.append( PushUp()(x,att) ) + + odict = {} + for k in indict.keys(): + odict[k] = SelectUp()(indict[k],trans_a) + + out.append(odict['x']) + out = tf.keras.layers.Concatenate()(out) + odict['x'] = out + + return trans_a, odict, trans_a['rs_up'] #for backscatter + + +def point_scatter(x, trans : list, dense_nodes = 64, name = ""): + ''' + watch out -> this can become big + ''' + + trans = trans.copy() + trans.reverse() + for t in range(len(trans)): + ta = trans[t] + x = SelectDown()(x,ta) + x = tf.keras.layers.Dense(dense_nodes, activation='elu', name = name+'_d_'+str(t))(x) + + return x + + + + + + + + + + + + + + + diff --git a/modules/GravNetLayersRagged.py b/modules/GravNetLayersRagged.py index 84889e29..0febe90b 100644 --- a/modules/GravNetLayersRagged.py +++ b/modules/GravNetLayersRagged.py @@ -40,7 +40,8 @@ def AccumulateKnnSumw(distances, features, indices, mean_and_max=False): fmean = f[:,:origshape] fnorm = f[:,origshape:origshape+1] - fmean = tf.math.divide_no_nan(fmean,fnorm) + fnorm = tf.where(fnorm<1e-3, 1e-3, fnorm) + fmean = tf.math.divide_no_nan(fmean, fnorm) fmean = tf.reshape(fmean, [-1,origshape]) if mean_and_max: fmean = tf.concat([fmean, f[:,origshape+1:-1]],axis=1) @@ -216,7 +217,21 @@ def call(self, inputs ): class Where(tf.keras.layers.Layer): - def __init__(self, outputval, condition = '>0', **kwargs): + def __init__(self, outputval = None , condition = '>0', **kwargs): + ''' + Simple wrapper around tf.where. 
+ + Inputs if outputval=None: + - tensor defining condition + - value to return if condition == True + - value to return else + + Inputs if outputval=val: + - tensor defining condition + - value to return if condition is not fulfilled + --> will return constant outputval=val if condition is fulfilled + + ''' conditions = ['>0','>=0','<0','<=0','==0', '!=0'] assert condition in conditions self.condition = condition @@ -237,20 +252,29 @@ def compute_output_shape(self, input_shapes): return (input_shapes[1],) def call(self,inputs): - assert len(inputs)==2 + + if self.outputval is not None: + assert len(inputs)==2 + left = self.outputval + right = inputs[1] + else: + assert len(inputs)==3 + left = inputs[1] + right = inputs[2] + izero = tf.constant(0,dtype=inputs[0].dtype) if self.condition == '>0': - return tf.where(inputs[0]> izero, self.outputval, inputs[1]) + return tf.where(inputs[0]> izero, left, right) elif self.condition == '>=0': - return tf.where(inputs[0]>=izero, self.outputval, inputs[1]) + return tf.where(inputs[0]>=izero, left, right) elif self.condition == '<0': - return tf.where(inputs[0]< izero, self.outputval, inputs[1]) + return tf.where(inputs[0]< izero, left, right) elif self.condition == '<=0': - return tf.where(inputs[0]<=izero, self.outputval, inputs[1]) + return tf.where(inputs[0]<=izero, left, right) elif self.condition == '!=0': - return tf.where(inputs[0]!=izero, self.outputval, inputs[1]) + return tf.where(inputs[0]!=izero, left, right) else: - return tf.where(inputs[0]==izero, self.outputval, inputs[1]) + return tf.where(inputs[0]==izero, left, right) class MixWhere(tf.keras.layers.Layer): @@ -1018,6 +1042,8 @@ def __init__(self, **kwargs): super(GooeyBatchNorm, self).__init__(**kwargs) + raise ValueError("Layer deprecated, please use ScaledGooeyBatchNorm2") + assert viscosity >= 0 and viscosity <= 1. assert fluidity_decay >= 0 and fluidity_decay <= 1. assert max_viscosity >= viscosity @@ -1177,14 +1203,18 @@ def call(self, inputs, training=None): s,v = tf.sign(inputs), tf.abs(inputs) out = super(SignedScaledGooeyBatchNorm, self).call(v, training) return s*out - -class ScaledGooeyBatchNorm2(LayerWithMetrics): + + +class ScaledGooeyBatchNorm2(tf.keras.layers.Layer): def __init__(self, viscosity=0.01, fluidity_decay=1e-4, - max_viscosity=0.99, + max_viscosity=0.99999, no_gaus = True, epsilon=1e-2, + invert_condition=False, + _promptnames=None, #compatibility, does nothing + record_metrics=False, #compatibility, does nothing **kwargs): ''' Input features (or [features, condition]), output: normed features @@ -1196,7 +1226,8 @@ def __init__(self, - fluidity_decay: 'thickening' of the viscosity (see scripts/gooey_plot.py for visualisation) - no_gaus: do not take variance but take mean difference to mean. Better for non-gaussian inputs and much more robust.
- - epsilon: when dividing, added to the denominator (should not require adjustment) + - epsilon: when dividing, added to the denominator (should not require adjustment), + - invert_condition: instead of >0.5 uses <=0.5 ''' super(ScaledGooeyBatchNorm2, self).__init__(**kwargs) @@ -1211,6 +1242,7 @@ def __init__(self, self.viscosity_init = viscosity self.epsilon = epsilon self.no_gaus = no_gaus + self.invert_condition = invert_condition def compute_output_shape(self, input_shapes): #return input_shapes[0] @@ -1223,34 +1255,34 @@ def get_config(self): 'fluidity_decay': self.fluidity_decay, 'max_viscosity': self.max_viscosity, 'epsilon': self.epsilon, - 'no_gaus': self.no_gaus + 'no_gaus': self.no_gaus, + 'invert_condition': self.invert_condition } base_config = super(ScaledGooeyBatchNorm2, self).get_config() return dict(list(base_config.items()) + list(config.items())) def build(self, input_shapes): - + #shape = (1,)+input_shapes[0][1:] if isinstance(input_shapes,list): shape = (1,)+input_shapes[0][1:] else: shape = (1,)+input_shapes[1:] + + self.bias = self.add_weight(name = 'bias',shape = shape, + initializer = 'zeros', trainable = self.trainable) + self.gamma = self.add_weight(name = 'gamma',shape = shape, + initializer = 'ones', trainable = self.trainable) - self.mean = self.add_weight(name = 'mean',shape = shape, - initializer = 'zeros', trainable = False) - self.den = self.add_weight(name = 'den',shape = shape, + self.mean = self.add_weight(name = 'mean',shape = shape, + initializer = 'zeros', trainable = False) + self.den = self.add_weight(name = 'den',shape = shape, initializer = 'ones', trainable = False) - self.viscosity = tf.Variable(initial_value=self.viscosity_init, + self.viscosity = tf.Variable(initial_value=self.viscosity_init, name='viscosity', trainable=False,dtype='float32') - - self.bias = self.add_weight(name = 'bias',shape = shape, - initializer = 'zeros', trainable = self.trainable) - - self.gamma = self.add_weight(name = 'gamma',shape = shape, - initializer = 'ones', trainable = self.trainable) - + super(ScaledGooeyBatchNorm2, self).build(input_shapes) def _m_mean(self, x, mask): @@ -1283,7 +1315,10 @@ def _calc_out(self, x_in, cond): out = (x_in - ngmean) / (tf.abs(ngden) + self.epsilon) out = out*self.gamma + self.bias - return tf.where(cond>0.5, out, x_in) + if self.invert_condition: + return tf.where(cond<=0.5, out, x_in) + else: + return tf.where(cond>0.5, out, x_in) def call(self, inputs, training=None): if isinstance(inputs,list): @@ -1293,6 +1328,10 @@ def call(self, inputs, training=None): x_in = inputs cond = tf.ones_like(x_in[...,0:1]) + + #print(self.name, self.mean) + #tf.print(self.name, self.mean) + if not self.trainable: return self._calc_out(x_in, cond) @@ -1324,7 +1363,7 @@ def call(self, inputs, training=None): -class ConditionalScaledGooeyBatchNorm(LayerWithMetrics): +class ConditionalScaledGooeyBatchNorm(tf.keras.layers.Layer): def __init__(self,**kwargs): ''' Inputs (list): @@ -1337,38 +1376,9 @@ def __init__(self,**kwargs): Options: see ScaledGooeyBatchNorm2 options, will be passed as kwargs ''' - super(ConditionalScaledGooeyBatchNorm, self).__init__(**kwargs) - if 'name' in kwargs.keys(): - kwargs.pop('name') + raise ValueError("problems with weight saving, use two ScaledGooeyBatchNorm2 layers and invert condition on one.") - with tf.name_scope(self.name + "/1/"): - self.bn_a = ScaledGooeyBatchNorm2(name=self.name+'_bn_a',**kwargs) - with tf.name_scope(self.name + "/2/"): - self.bn_b = ScaledGooeyBatchNorm2(name=self.name+'_bn_b',**kwargs) - - 
def compute_output_shape(self, input_shapes): - #return input_shapes[0] - return self.bn_a.compute_output_shape(input_shapes) - - def build(self, input_shapes): - with tf.name_scope(self.name + "/1/"): - self.bn_a.build(input_shapes) - with tf.name_scope(self.name + "/2/"): - self.bn_b.build(input_shapes) - - super(ConditionalScaledGooeyBatchNorm, self).build(input_shapes) - - def call(self, inputs, training=None): - x, cond = inputs - cond = tf.where(cond > 0.5, tf.ones_like(cond), 0.) #make sure it's ones and zeros - - x_a = self.bn_a([x, cond],training = training) - x_b = self.bn_b([x, 1.-cond], training = training) - - return tf.where(cond>0.5, x_a, x_b) - - class ProcessFeatures(tf.keras.layers.Layer): def __init__(self, newformat=True,#compat can be restored but default is new format @@ -1996,8 +2006,8 @@ def compute_output_shape(self, input_shapes): def call(self,inputs): assert len(inputs)==2 coords, nidx = inputs - - ncoords = SelectWithDefault(nidx, coords, 0.) # V x K x C + #no check needed here + ncoords = SelectWithDefault(nidx, coords, 0., no_check=True) # V x K x C dist = tf.reduce_sum( (ncoords - tf.expand_dims(coords,axis=1))**2, axis=2 ) return dist @@ -2456,13 +2466,15 @@ def compute_output_signature(self, input_signature): return [tf.TensorSpec(dtype=input_dtypes[i], shape=output_shapes[i]) for i in range(len(output_shapes))] @staticmethod - def raw_call(distances, nidx, K, radius, sort, incr_sorting_score, keep_self=True): + def raw_call(distances, nidx, K, radius=-1, sort=True, incr_sorting_score = None, keep_self=True): K = K if K>0 else distances.shape[1] if not sort: return distances[:,:K],nidx[:,:K] - if tf.shape(incr_sorting_score)[1] is not None and tf.shape(incr_sorting_score)[1]==1: + if incr_sorting_score is None: + incr_sorting_score = distances + elif tf.shape(incr_sorting_score)[1] is not None and tf.shape(incr_sorting_score)[1]==1: incr_sorting_score = SelectWithDefault(nidx, incr_sorting_score, 0.)[:,0] tfssc = tf.where(nidx<0, 1e9, incr_sorting_score) #make sure the -1 end up at the end @@ -3092,7 +3104,7 @@ def call(self, inputs, training=None): ######## generic neighbours -class RaggedGravNet(LayerWithMetrics): +class RaggedGravNet(tf.keras.layers.Layer): def __init__(self, n_neighbours: int, n_dimensions: int, @@ -3106,6 +3118,8 @@ def __init__(self, use_dynamic_knn=True, debug = False, n_knn_bins=None, + _promptnames=None, #compatibility, does nothing + record_metrics=False, #compatibility, does nothing **kwargs): """ Call will return output features, coordinates, neighbor indices and squared distances from neighbors @@ -3133,7 +3147,7 @@ def __init__(self, #n_neighbours += 1 # includes the 'self' vertex assert n_neighbours > 1 assert not use_approximate_knn #not needed anymore. 
Exact one is faster by now - + self.n_neighbours = n_neighbours self.n_dimensions = n_dimensions self.n_filters = n_filters @@ -3152,9 +3166,12 @@ def __init__(self, self.input_feature_transform = tf.keras.layers.Dense(n_propagate, activation=feature_activation) with tf.name_scope(self.name + "/2/"): + s_kernel_initializer = 'glorot_uniform' + if coord_initialiser_noise is not None: + s_kernel_initializer = EyeInitializer(mean=0, stddev=coord_initialiser_noise) self.input_spatial_transform = tf.keras.layers.Dense(n_dimensions, #very slow turn on - kernel_initializer=EyeInitializer(mean=0, stddev=coord_initialiser_noise), + kernel_initializer=s_kernel_initializer, use_bias=False) with tf.name_scope(self.name + "/3/"): @@ -3170,27 +3187,18 @@ def build(self, input_shapes): self.input_feature_transform.build(input_shape) with tf.name_scope(self.name + "/2/"): - self.input_spatial_transform.build(input_shape) + if len(input_shapes) == 3: #extra coords + c_shape = [s for s in input_shape] + c_shape[-1] += input_shapes[2][-1] + self.input_spatial_transform.build(c_shape) + else: + self.input_spatial_transform.build(input_shape) with tf.name_scope(self.name + "/3/"): self.output_feature_transform.build((input_shape[0], self.n_prop_total + input_shape[1])) super(RaggedGravNet, self).build(input_shape) - def update_dynamic_radius(self, dist, training): - if not self.use_dynamic_knn or not self.trainable: - return - #update slowly, with safety margin - lindist = tf.sqrt(dist) - update = tf.reduce_max(lindist)*1.05 #can be inverted for performance TBI - mean_dist = tf.reduce_mean(lindist) - low_update = tf.where(update>2.,2.,update)#receptive field ends at 1. - update = tf.where(low_update>2.*mean_dist,low_update,2.*mean_dist)#safety setting to not loose all neighbours - update += 1e-3 - update = self.dynamic_radius + 0.05*(update-self.dynamic_radius) - updated_radius = tf.keras.backend.in_train_phase(update,self.dynamic_radius,training=training) - #print('updated_radius',updated_radius) - tf.keras.backend.update(self.dynamic_radius,updated_radius) def create_output_features(self, x, neighbour_indices, distancesq): allfeat = [] @@ -3214,8 +3222,11 @@ def priv_call(self, inputs, training=None): if row_splits.shape[0] is not None: tf.assert_equal(row_splits[-1], x.shape[0]) + x_coord = x + if len(inputs) == 3: + x_coord = tf.concat([inputs[2], x], axis=-1) - coordinates = self.input_spatial_transform(x) + coordinates = self.input_spatial_transform(x_coord) neighbour_indices, distancesq, sidx, sdist = self.compute_neighbours_and_distancesq(coordinates, row_splits, training) neighbour_indices = tf.reshape(neighbour_indices, [-1, self.n_neighbours]) #for proper output shape for keras distancesq = tf.reshape(distancesq, [-1, self.n_neighbours]) @@ -3254,8 +3265,6 @@ def compute_neighbours_and_distancesq(self, coordinates, row_splits, training): dist = tf.where(idx<0,0.,dist) - self.update_dynamic_radius(dist,training) - if self.return_self: return idx[:, 1:], dist[:, 1:], idx, dist return idx[:, 1:], dist[:, 1:], None, None @@ -3264,9 +3273,11 @@ def compute_neighbours_and_distancesq(self, coordinates, row_splits, training): def collect_neighbours(self, features, neighbour_indices, distancesq): f = None if self.sumwnorm: - f,_ = AccumulateKnnSumw(10.*distancesq, features, neighbour_indices) + f,_ = AccumulateKnnSumw(10.*distancesq, features, + neighbour_indices, mean_and_max=True) else: - f,_ = AccumulateKnn(10.*distancesq, features, neighbour_indices) + f,_ = AccumulateKnn(10.*distancesq, features, 
neighbour_indices, + mean_and_max=True) return f def get_config(self): @@ -3303,7 +3314,7 @@ def call(self, input): return input * att -class MultiAttentionGravNetAdd(LayerWithMetrics): +class MultiAttentionGravNetAdd(tf.keras.layers.Layer): def __init__(self, n_attention_kernels :int, **kwargs): @@ -3365,10 +3376,10 @@ def call(self, inputs): for di in range(len(self.kernel_coord_dense)): refcadd = self.kernel_coord_dense[di](feat) - for i in range(coord.shape[-1]): - meancoord = tf.reduce_mean(refcadd[:,i]) - self.add_prompt_metric(meancoord, self.name+'_coord_add_mean_'+str(di)+'_'+str(i)) - self.add_prompt_metric(tf.math.reduce_std(refcadd[:,i]-meancoord), self.name+'_coord_add_var_'+str(di)+'_'+str(i)) + #for i in range(coord.shape[-1]): + # meancoord = tf.reduce_mean(refcadd[:,i]) + # self.add_prompt_metric(meancoord, self.name+'_coord_add_mean_'+str(di)+'_'+str(i)) + # self.add_prompt_metric(tf.math.reduce_std(refcadd[:,i]-meancoord), self.name+'_coord_add_var_'+str(di)+'_'+str(i)) refcoord = refcadd + coord refcoord = tf.expand_dims(refcoord,axis=1)#V x 1 x C @@ -3548,6 +3559,7 @@ def __init__(self, n_feature_transformation, #=[32, 32, 32, 32, 4, 4], sumwnorm=False, activation='relu', + exp_distances = True, #use feat * exp(-distance) weighting, if not simple feat * distance **kwargs): super(DistanceWeightedMessagePassing, self).__init__(**kwargs) @@ -3555,6 +3567,7 @@ def __init__(self, self.sumwnorm = sumwnorm self.feature_tranformation_dense = [] self.activation = activation + self.exp_distances = exp_distances for i in range(len(self.n_feature_transformation)): with tf.name_scope(self.name + "/5/" + str(i)): self.feature_tranformation_dense.append(tf.keras.layers.Dense(self.n_feature_transformation[i], @@ -3579,6 +3592,7 @@ def compute_output_shape(self, inputs_shapes): def get_config(self): config = {'n_feature_transformation': self.n_feature_transformation, 'activation': self.activation, + 'exp_distances': self.exp_distances, 'sumwnorm':self.sumwnorm } base_config = super(DistanceWeightedMessagePassing, self).get_config() @@ -3603,9 +3617,15 @@ def create_output_features(self, x, neighbour_indices, distancesq): def collect_neighbours(self, features, neighbour_indices, distancesq): f=None if self.sumwnorm: - f,_ = AccumulateKnnSumw(10.*distancesq, features, neighbour_indices, mean_and_max=True) + if self.exp_distances: + f,_ = AccumulateKnnSumw(10.*distancesq, features, neighbour_indices, mean_and_max=True) + else: + f,_ = AccumulateLinKnnSumw(distancesq, features, neighbour_indices, mean_and_max=True) else: - f,_ = AccumulateKnn(10.*distancesq, features, neighbour_indices) + if self.exp_distances: + f,_ = AccumulateKnn(10.*distancesq, features, neighbour_indices) + else: + f,_ = AccumulateLinKnn(distancesq, features, neighbour_indices) return f def call(self, inputs): @@ -3980,6 +4000,8 @@ def __init__(self,**kwargs): super(FlatNeighbourFeatures, self).__init__(**kwargs) def call(self, inputs): + + assert len(inputs) == 2 feat,nidx = inputs n_feat = SelectWithDefault(nidx,feat,0.) 
# [V x K x F] diff --git a/modules/Layers.py b/modules/Layers.py index 73b86c9e..c76e0793 100644 --- a/modules/Layers.py +++ b/modules/Layers.py @@ -264,6 +264,10 @@ from GravNetLayersRagged import LocalGravNetAttention global_layers_list['LocalGravNetAttention']=LocalGravNetAttention + +from GravNetLayersRagged import FlatNeighbourFeatures +global_layers_list['FlatNeighbourFeatures']=FlatNeighbourFeatures + ### odd debug layers from DebugLayers import PlotCoordinates global_layers_list['PlotCoordinates']=PlotCoordinates @@ -360,9 +364,71 @@ import tensorflow as tf +class GroupSortActivation(tf.keras.layers.Layer): + + def compute_output_shape(self, input_shapes): + return input_shapes + + def call(self, inputs): + out = tf.sort(inputs, axis=-1) + return tf.reshape(out, tf.shape(inputs)) + +global_layers_list['GroupSortActivation']=GroupSortActivation +def layernorm(x, return_norm=False): + x = x - tf.reduce_mean(x,axis=-1, keepdims=True) + norm = tf.reduce_sum(x**2, axis=-1,keepdims=True) + norm = tf.sqrt(norm+1e-6) + if return_norm: + x = tf.concat([x / norm * tf.sqrt(tf.cast(x.shape[-1],'float32')), norm], axis=-1) + else: + x = x / norm * tf.sqrt(tf.cast(x.shape[-1],'float32')) + return x + +global_layers_list['layernorm']= layernorm #convenience + +class SphereActivation(tf.keras.layers.Layer): + ''' + a layer norm that can also return the norm + ''' + + def __init__(self,return_norm = False, **kwargs): + super(SphereActivation, self).__init__(**kwargs) + self.return_norm = return_norm + + def get_config(self): + config = {'return_norm': self.return_norm} + base_config = super(SphereActivation, self).get_config() + return dict(list(base_config.items()) + list(config.items() )) + + def call(self, x): + if not self.return_norm: + return layernorm(x, False) + else: + out = layernorm(x, True) + return out[...,:x.shape[-1]], out[...,x.shape[-1]:x.shape[-1]+1] + +global_layers_list['SphereActivation']=SphereActivation + +class Multi(tf.keras.layers.Layer): + + def call(self, inputs): + assert len(inputs)==2 + x,y = inputs + return x*y #but with broadcasting + +global_layers_list['Multi']=Multi +class Sqrt(tf.keras.layers.Layer): + + def compute_output_shape(self, input_shapes): + return input_shapes + + def call(self, x): + return tf.sqrt(x + 1e-6) + +global_layers_list['Sqrt']=Sqrt class SplitFeatures(Layer): def __init__(self,**kwargs): diff --git a/modules/LossLayers.py b/modules/LossLayers.py index e6ebe3ae..cc8a5893 100644 --- a/modules/LossLayers.py +++ b/modules/LossLayers.py @@ -265,7 +265,7 @@ def call(self, inputs): self.maybe_print_loss(lossval,now) - lossval = tf.debugging.check_numerics(lossval, self.name+" produced inf or nan.") + #lossval = tf.debugging.check_numerics(lossval, self.name+" produced inf or nan.") #this can happen for empty batches. If there are deeper problems, check in the losses themselves #lossval = tf.where(tf.math.is_finite(lossval), lossval ,0.) 
if not self.return_lossval: @@ -289,10 +289,11 @@ def maybe_print_loss(self,lossval,stime=None): if self.print_loss: if hasattr(lossval, 'numpy'): print(self.name, 'loss', lossval.numpy()) + tf.print(self.name, 'loss', lossval.numpy()) else: tf.print(self.name, 'loss', lossval) + print(self.name, 'loss', lossval) - if self.print_batch_time or self.record_metrics: now = tf.timestamp() prev = self.time @@ -353,6 +354,52 @@ def loss(self, inputs): +class LLObjectValuePenalty(LLValuePenalty): + + def __init__(self, + noise_scale = 10., + **kwargs): + ''' + Simple value penalty loss, tries to keep values around default using simple + L2 regularisation; normalises per object + + inputs: + - value to penalise + - t_idx + + returns input + ''' + self.noise_scale = noise_scale + super(LLObjectValuePenalty, self).__init__(**kwargs) + + def get_config(self): + config = {'noise_scale': self.noise_scale} + base_config = super(LLObjectValuePenalty, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def loss(self, inputs): + assert len(inputs) == 2 + val, tidx = inputs + + Msel,_,_ = CreateMidx(tidx, calc_m_not=False) + + if Msel is None or tf.shape(Msel)[0] == None: + return 0. + + val_k_m = SelectWithDefault(Msel, val, self.default) #K x V-obj x 1 + mask_k_m = SelectWithDefault(Msel, tf.ones_like(val), 0.) #K x V-obj x 1 + vloss = (self.default - val_k_m) ** 2 + vloss = tf.math.divide_no_nan(tf.reduce_sum(vloss, axis=1), + tf.reduce_sum(mask_k_m, axis=1) + 1e-6) + vloss = tf.reduce_mean(vloss) + + #now the noise + is_noise = tf.cast( tidx < 0, dtype='float32' ) + vloss += self.noise_scale * tf.math.divide_no_nan(tf.reduce_sum(is_noise * (self.default - val)**2), + tf.reduce_sum(is_noise) + 1e-6) + return vloss + + class CreateTruthSpectatorWeights(tf.keras.layers.Layer): def __init__(self, @@ -1141,7 +1188,7 @@ def _rs_loop(self,coords, tidx, specweight, energy): return distloss+reploss, distloss, reploss - def raw_loss(self,acoords, atidx, aspecw, aenergy, rs, downsample): + def raw_loss(self,acoords, atidx, aspecw, aenergy, rs): lossval = tf.zeros_like(acoords[0,0]) reploss = tf.zeros_like(acoords[0,0]) @@ -1157,8 +1204,14 @@ def raw_loss(self,acoords, atidx, aspecw, aenergy, rs, downsample): specw = aspecw[rs[i]:rs[i+1]] energy = aenergy[rs[i]:rs[i+1]] - if downsample>0 and downsample < coords.shape[0]: - sel = tf.random.uniform(shape=(downsample,), minval=0, maxval=coords.shape[0]-1, dtype=tf.int32) + if self.downsample > 0:# and self.downsample < coords.shape[0]: + sel = tf.range(coords.shape[0]) + sel = tf.random.shuffle(sel) + + length = tf.reduce_min([tf.constant(self.downsample), tf.shape(coords)[0]]) + + sel = sel[:length] + #sel = tf.random.uniform(shape=(self.downsample,), minval=0, maxval=coords.shape[0]-1, dtype=tf.int32) sel = tf.expand_dims(sel,axis=1) coords = tf.gather_nd(coords, sel) tidx = tf.gather_nd(tidx, sel) @@ -1195,7 +1248,7 @@ def loss(self, inputs): # return zero_loss lossval,distloss, reploss = self.raw_loss( - coords, tidx, specw, energy, rs, self.downsample) + coords, tidx, specw, energy, rs) lossval = tf.where(tf.math.is_finite(lossval), lossval, 0.)#DEBUG @@ -1824,12 +1877,14 @@ def __init__(self, super(LLBasicObjectCondensation, self).__init__(**kwargs) - from object_condensation import Basic_OC_per_sample, PushPull_OC_per_sample, PreCond_kNNOC_per_sample, PreCond_OC_per_sample + from object_condensation import Basic_OC_per_sample, PushPull_OC_per_sample, Hinge_OC_per_sample, PreCond_OC_per_sample impl = Basic_OC_per_sample if 
implementation == 'pushpull': impl = PushPull_OC_per_sample if implementation == 'precond': impl = PreCond_OC_per_sample + if implementation == 'hinge': + impl = Hinge_OC_per_sample self.oc_loss_object = OC_loss( loss_impl = impl, @@ -1915,6 +1970,7 @@ def __init__(self, *, energy_loss_weight=1., div_repulsion=False, dynamic_payload_scaling_onset=-0.005, beta_push=0., + implementation = '', **kwargs): """ Read carefully before changing parameters @@ -1967,9 +2023,24 @@ def __init__(self, *, energy_loss_weight=1., if huber_energy_scale>0 and alt_energy_loss: raise ValueError("huber_energy_scale>0 and alt_energy_loss exclude each other") + + from object_condensation import Basic_OC_per_sample, PushPull_OC_per_sample, Hinge_OC_per_sample, Hinge_Manhatten_OC_per_sample, PreCond_OC_per_sample + impl = Basic_OC_per_sample + if implementation == 'pushpull': + impl = PushPull_OC_per_sample + if implementation == 'precond': + impl = PreCond_OC_per_sample + if implementation == 'hinge': + impl = Hinge_OC_per_sample + if implementation == 'hinge_manhatten': + impl = Hinge_Manhatten_OC_per_sample + self.implementation = implementation + + #configuration here, no need for all that stuff below #as far as the OC part is concerned (still config for payload though) self.oc_loss_object = OC_loss( + loss_impl = impl, q_min= q_min, s_b=s_b, use_mean_x=use_average_cc_pos, @@ -2050,6 +2121,9 @@ def calc_energy_correction_factor_loss(self, t_energy, t_dep_energies, pred_energy,pred_energy_low_quantile,pred_energy_high_quantile, return_concat=False): + if self.energy_loss_weight == 0.: + return pred_energy**2 + pred_energy_low_quantile**2 + pred_energy_high_quantile**2 + ediff = (t_energy - pred_energy*t_dep_energies)/tf.sqrt(tf.abs(t_energy)+1e-3) ediff = tf.debugging.check_numerics(ediff, "eloss ediff") @@ -2245,10 +2319,11 @@ def loss(self, inputs): if is_spectator is None: is_spectator = tf.zeros_like(pred_beta) - full_payload = tf.debugging.check_numerics(full_payload,"full_payload has nans of infs") - pred_ccoords = tf.debugging.check_numerics(pred_ccoords,"pred_ccoords has nans of infs") - energy_weights = tf.debugging.check_numerics(energy_weights,"energy_weights has nans of infs") - pred_beta = tf.debugging.check_numerics(pred_beta,"beta has nans of infs") + #just go with it + #full_payload = tf.debugging.check_numerics(full_payload,"full_payload has nans of infs") + #pred_ccoords = tf.debugging.check_numerics(pred_ccoords,"pred_ccoords has nans of infs") + #energy_weights = tf.debugging.check_numerics(energy_weights,"energy_weights has nans of infs") + #pred_beta = tf.debugging.check_numerics(pred_beta,"beta has nans of infs") #safe guards with tf.control_dependencies( [tf.assert_equal(rowsplits[-1], pred_beta.shape[0]), @@ -2395,7 +2470,8 @@ def get_config(self): 'super_attraction':self.super_attraction, 'div_repulsion' : self.div_repulsion, 'dynamic_payload_scaling_onset': self.dynamic_payload_scaling_onset, - 'beta_push': self.beta_push + 'beta_push': self.beta_push, + 'implementation': self.implementation } base_config = super(LLFullObjectCondensation, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/modules/MetricsLayers.py b/modules/MetricsLayers.py index 04437604..38fbf856 100644 --- a/modules/MetricsLayers.py +++ b/modules/MetricsLayers.py @@ -167,9 +167,11 @@ def __init__(self, **kwargs): def metrics_call(self, inputs): - #tren = None + istrack = None if len(inputs)==5: gsel,tidx,ten,rs,srs = inputs + if len(inputs)==6: + gsel,tidx,ten,istrack,rs,srs 
= inputs #tf.assert_equal(tidx.shape,ten.shape)#safety alltruthcount = None @@ -181,6 +183,7 @@ def metrics_call(self, inputs): return stidx, sten = tf.constant([[0]],dtype='int32'), tf.constant([[0.]],dtype='float32') + n_track_before, n_track_after = tf.constant([[0.]],dtype='float32'),tf.constant([[0.]],dtype='float32') if self.active: stidx, sten = SelIdx.raw_call(gsel,[tidx,ten]) @@ -198,7 +201,12 @@ def metrics_call(self, inputs): seltruthcount = u.shape[0] else: seltruthcount += u.shape[0] - + + if istrack is not None: + n_track_before = tf.reduce_sum(istrack) + n_track_after = SelIdx.raw_call(gsel,[istrack]) + n_track_after = tf.reduce_sum(n_track_after) + nonoisecounts_bef = tf.concat(nonoisecounts_bef,axis=0) nonoisecounts_after = tf.concat(nonoisecounts_after,axis=0) @@ -222,14 +230,20 @@ def metrics_call(self, inputs): lostenergies = ue[c<2] #print(lostenergies) - self.add_prompt_metric(tf.reduce_mean(nonoisecounts_bef),self.name+'_hits_pobj_bef_mean') + self.add_prompt_metric(tf.reduce_mean(tf.cast(nonoisecounts_bef,'float32')),self.name+'_hits_pobj_bef_mean') self.add_prompt_metric(tf.reduce_max(nonoisecounts_bef),self.name+'_hits_pobj_bef_max') - self.add_prompt_metric(tf.reduce_mean(nonoisecounts_after),self.name+'_hits_pobj_after_mean') + self.add_prompt_metric(tf.reduce_mean(tf.cast(nonoisecounts_after,'float32')),self.name+'_hits_pobj_after_mean') self.add_prompt_metric(tf.reduce_max(nonoisecounts_after),self.name+'_hits_pobj_after_max') - self.add_prompt_metric(tf.reduce_mean(lostenergies),self.name+'_lost_energy_mean') - self.add_prompt_metric(tf.reduce_max(lostenergies),self.name+'_lost_energy_max') + l_em = tf.reduce_mean(lostenergies) + l_em = tf.where(tf.math.is_finite(l_em),l_em, 0.) + + l_ema = tf.reduce_max(lostenergies) + l_ema = tf.where(tf.math.is_finite(l_ema),l_ema, 0.) 
+ + self.add_prompt_metric(l_em,self.name+'_lost_energy_mean') + self.add_prompt_metric(l_ema,self.name+'_lost_energy_max') self.add_prompt_metric(tot_lost_en_sum,self.name+'_lost_energy_sum') reduced_to_fraction = tf.cast(srs[-1],dtype='float32')/tf.cast(rs[-1],dtype='float32') @@ -239,6 +253,11 @@ def metrics_call(self, inputs): no_noise_hits_aft = tf.cast(tf.math.count_nonzero(stidx+1) ,dtype='float32') self.add_prompt_metric(no_noise_hits_aft/no_noise_hits_bef,self.name+'_no_noise_reduction') + if istrack is not None: + self.add_prompt_metric(n_track_before,self.name+'_tracks_bef') + self.add_prompt_metric(n_track_after,self.name+'_tracks_after') + + diff --git a/modules/accknn_op.py b/modules/accknn_op.py index 14780ac4..03a88c8a 100644 --- a/modules/accknn_op.py +++ b/modules/accknn_op.py @@ -14,6 +14,8 @@ _accknn_op = tf.load_op_library('accumulate_knn.so') _accknn_grad_op = tf.load_op_library('accumulate_knn_grad.so') +if gl.acc_ops_use_tf_gradients: + print('accknn_op: warning, running with less memory efficient TF gradients.') def AccumulateLinKnn(weights, features, indices, mean_and_max=True, force_tf=False): diff --git a/modules/callbacks.py b/modules/callbacks.py index 2d0ae12e..ec7417a4 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -2,6 +2,7 @@ from DeepJetCore.training.DeepJet_callbacks import PredictCallback from multiprocessing import Process import numpy as np +import tensorflow as tf from OCHits2Showers import process_endcap, OCGatherEnergyCorrFac from datastructures import TrainData_NanoML @@ -12,6 +13,9 @@ from Layers import DictModel from plotting_tools import publish, shuffle_truth_colors +from DebugLayers import _DebugPlotBase +from DeepJetCore import TrainData +from DeepJetCore.dataPipeline import TrainDataGenerator class plotDuringTrainingBase(PredictCallback): @@ -341,12 +345,16 @@ def _make_plot(self, counter, feat, predicted, truth): td = TrainData_NanoML() preddict=predicted + rs = feat[-1] + + if 'sel_idx' in predicted.keys(): + feat = [tf.gather_nd(f, predicted['sel_idx']).numpy() for f in feat if len(f.shape)>1] cdata=td.createTruthDict(feat) cdata['predBeta'] = preddict['pred_beta'] cdata['predCCoords'] = preddict['pred_ccoords'] cdata['predD'] = preddict['pred_dist'] - rs = feat[-1]#last one has to be row splits + #last one has to be row splits # this will not work, since it will be adapted by batch, and not anymore the right tow splits #rs = preddict['row_splits'] @@ -606,6 +614,111 @@ def _on_train_batch_end(self, batch, logs=None): +class NanSweeper(tf.keras.callbacks.Callback): + ''' + Replaces non-finite (NaN or Inf) model weights after every batch with the last finite values seen; on the very first batch, non-finite entries are replaced by small random values. + ''' + + def __init__(self): + super().__init__() + self.saved_weights = None + + def on_batch_end(self,batch,logs={}): + + mw = self.model.get_weights() + + if self.saved_weights is None: + self.saved_weights = [] + for w in mw: + w = tf.where(tf.math.is_finite(w),w,tf.random.normal(w.shape, stddev=1e-3)) + self.saved_weights.append(w) + return + nw = [] + n_nans = 0 + for w,sw in zip(mw, self.saved_weights): + nw.append( tf.where( tf.math.is_finite(w), w, sw ).numpy()) + n_nans += tf.reduce_sum( + tf.cast(tf.logical_not(tf.math.is_finite(w)),'int32') + ).numpy() + + if n_nans>0: + print("NanSweeper: removed", n_nans, "NaNs or Infs") + #find them: + for w in self.model.weights: + if np.all(np.isfinite(w.numpy())): + continue + print(w.name, 'had NaNs') + + self.model.set_weights(nw) + + + self.saved_weights = nw + + + + +class DebugPlotRunner(tf.keras.callbacks.Callback): + 
''' + Runs the model on one fixed, skimmed event every plot_frequency batches and temporarily triggers all debug plot layers so their plots are produced during training. + ''' + + def __init__(self, + sample : str, + plot_frequency=500, + adapt_outname = '', + use_event=0): + + super().__init__() + self.plot_frequency = plot_frequency + self.sample = sample + self.changed_layers=[] + self.adapt_outname = adapt_outname + self.counter = 0 + + #load the sample + assert sample[-6:] == '.djctd' + + #load one event + td = TrainData() + td.readFromFile(sample) + td.skim(use_event) + self.data = td.transferFeatureListToNumpy(False) + + + def _trigger_plots(self): + + self.changed_layers=[] + for l in self.model.layers: + if isinstance(l, _DebugPlotBase): + self.changed_layers.append( + {'l': l, 'o':l.outdir } + ) + l.triggered = True + l.outdir = l.outdir + self.adapt_outname + + def _set_model_back(self): + for l in self.changed_layers: + l['l'].triggered = False + l['l'].outdir = l['o'] + + def on_batch_end(self,batch,logs={}): + + #check if it should run + if self.counter < self.plot_frequency: + self.counter += 1 + return + + self.counter = 0 + + self._trigger_plots() + + #run model + _ = self.model(self.data) + + self._set_model_back() + + + diff --git a/modules/compiled/push_knn_kernel.cu.cc b/modules/compiled/push_knn_kernel.cu.cc index fd6c0733..a71d6d21 100644 --- a/modules/compiled/push_knn_kernel.cu.cc +++ b/modules/compiled/push_knn_kernel.cu.cc @@ -41,7 +41,9 @@ void push_knn_kernel( int n_vert, int n_neigh, - int n_feat) { + int n_feat, + + bool atomic = true) { //switch off with care! //parallelise over neighbours and features - no race conditions @@ -53,12 +55,16 @@ void push_knn_kernel( int nidx = d_idxs[I2D(i_v,i_n,n_neigh)]; if(nidx<0) return; + if(nidx>=n_vert) asm("trap;"); //throw error float f = d_feat[I2D(i_v,i_f,n_feat)]; float w = d_weights[I2D(i_v,i_n,n_neigh)]; - atomicAdd(&d_out_feat[I2D(nidx,i_f,n_feat)] , f*w); + if(atomic) + atomicAdd(&d_out_feat[I2D(nidx,i_f,n_feat)] , f*w); + else + d_out_feat[I2D(nidx,i_f,n_feat)] += f*w; } @@ -87,6 +93,8 @@ struct PushKnnOpFunctor { cudaDeviceSynchronize(); + bool atomic = true; + //this should keep the atomic reasonably ok grid_and_block par( n_feat, 32, @@ -98,6 +106,7 @@ struct PushKnnOpFunctor { par = grid_and_block(n_feat, 128, n_vert, 1, n_neigh, 1);//no atomic *within* one block, still can be globally!
+ atomic = false; } else if(n_feat >= 64){ //32 and 64 are also rather standard par = grid_and_block( @@ -111,12 +120,12 @@ struct PushKnnOpFunctor { n_vert, 4, n_neigh, 2); } - if(n_feat < 2){ //this is for energy push, also standard + else if(n_feat < 2){ //this is for energy push, also standard par = grid_and_block(n_feat, 1, n_vert, 128, n_neigh, 2); } - if(n_feat <32){ //this is for energy push, also standard + else if(n_feat <32){ //this is for energy push, also standard par = grid_and_block(n_feat, 8, n_vert, 12, n_neigh, 2); @@ -132,7 +141,8 @@ struct PushKnnOpFunctor { n_vert, n_neigh, - n_feat); + n_feat, + atomic); cudaDeviceSynchronize(); } diff --git a/modules/compiled/tests/test_push_layer.py b/modules/compiled/tests/test_push_layer.py new file mode 100755 index 00000000..b60457e0 --- /dev/null +++ b/modules/compiled/tests/test_push_layer.py @@ -0,0 +1,158 @@ + +import tensorflow as tf +from binned_select_knn_op import BinnedSelectKnn +import numpy as np +from GraphCondensationLayers import PushUp, GraphCondensation + +import time +from push_knn_op import PushKnn +from push_knn_op import _tf_push_knn as tf_push_knn + +def make_data(nvert, nk, nf): + + f = tf.constant(np.random.rand(nvert,nf),dtype='float32') + w = tf.constant(np.random.rand(nvert,nk),dtype='float32') + + c = tf.constant(np.random.rand(nvert,3),dtype='float32') #just to get neighbours + nidx, _ = BinnedSelectKnn(nk+1, c, tf.constant([0,nvert],dtype='int32')) + + return tf.constant(f),tf.constant(w),nidx[:,1:] + + +def simple_replacement_sum(f,trans,weight = None, select=False, mode='sum', add_self=False): + #most simple way of implementing it + out = (f*0.).numpy() + wsum = (f[:,0:1]*0.).numpy() + if weight is None: + weight = f[:,0:1]*0. + 1. + + if add_self: + for i_f in range(out.shape[1]): + for i_v in range(f.shape[0]): + if trans['nidx_down'][i_v,0] < 0:#up + out[i_v,i_f] += f[i_v,i_f] * weight[i_v] * trans['weights_down'][i_v,0] + wsum[i_v] += weight[i_v] * trans['weights_down'][i_v,0] + + for i_f in range(out.shape[1]): + for i_v in range(f.shape[0]): + for i_n in range(trans['nidx_down'].shape[1]): + nidx = trans['nidx_down'][i_v,i_n] + if nidx < 0: + continue + out[nidx,i_f] += f[i_v,i_f] * weight[i_v] * trans['weights_down'][i_v,i_n] + wsum[nidx] += weight[i_v] * trans['weights_down'][i_v,i_n] + + if mode == 'mean': + for i_f in range(out.shape[1]): + for i_v in range(f.shape[0]): + if wsum[i_v]: + out[i_v,i_f] /= wsum[i_v] / out.shape[1] #divide by nfeat that was added above + + + out = tf.gather_nd(out, trans['sel_idx_up']) + return out + + +def simple_data(randomise = False): + f = tf.constant([ + [1., 1./2., 1./4], + [2., 2./2., 2./4], + [3., 3./2., 3./4], + [4., 4./2., 4./4], + [10.,11.,12.] + ]) + + wf = tf.constant([ + [2], + [1.5], + [3], + [0.5], + [3.] + ],dtype='float32') + + w = tf.constant([ + [1.5, 2.], + [1., 1.], + [2., 1.], + [20., 1.], + [2., 1.] + ]) + + nidx = tf.constant([ + [1, 3], + [-1, -1], + [1, -1], + [-1, -1], + [-1, -1] + ]) + + if randomise: + f = tf.constant(np.random.rand(*list(f.shape)),dtype='float32')+1. 
+ wf = tf.constant(np.random.rand(*list(wf.shape)),dtype='float32')+0.2 + w = tf.constant(np.random.rand(*list(w.shape)),dtype='float32')+0.1 #make it all numerically stable + ''' + 'rs_down', + 'rs_up', + 'nidx_down', + 'distsq_down', #in case it's needed + 'sel_idx_up', # -> can also be used to scatter + 'weights_down' + ''' + trans = GraphCondensation() + trans['rs_down'] = tf.constant([0,4],dtype='int32') + trans['rs_up'] = tf.constant([0,2],dtype='int32') + trans['nidx_down'] = nidx + trans['distsq_down'] = tf.abs(w) + trans['sel_idx_up'] = tf.constant([[1],[3],[4]],dtype='int32') + trans['weights_down'] = w + + + return f, tf.abs(wf), trans + +f, wf, trans = simple_data(True) + +#print(PushUp(mode='sum')(f,trans)) +#simple_replacement_sum(f,trans) +# +#print(PushUp(mode='sum',add_self=True)(f,trans)) +#simple_replacement_sum(f,trans,add_self=True) +# + +#print(PushUp(mode='sum')(f,trans, weight = wf)) +#simple_replacement_sum(f,trans, weight = wf) +# +#print(PushUp(mode='sum',add_self=True)(f,trans, weight = wf)) +#simple_replacement_sum(f,trans, weight = wf,add_self=True) +# +#exit() + +#print(PushUp(mode='mean')(f,trans)) +#simple_replacement_sum(f,trans,mode='mean') +# +#f, wf, trans = simple_data(randomise=True) +#print(PushUp(mode='mean')(f,trans)) +#simple_replacement_sum(f,trans,mode='mean') + +#exit() +pu = PushUp(mode='mean')(f,trans,weight = wf) +spu = simple_replacement_sum(f,trans,weight = wf,mode='mean') +print(pu) +print(spu) +print(pu-spu) + +pu = PushUp(mode='mean',add_self=True)(f,trans,weight = wf) +spu = simple_replacement_sum(f,trans,weight = wf,mode='mean',add_self=True) + +print(pu) +print(spu) +print(pu-spu) + + + + + + + + + + diff --git a/modules/datastructures/TrainData_crilin.py b/modules/datastructures/TrainData_crilin.py index 040cd4f5..d212099b 100644 --- a/modules/datastructures/TrainData_crilin.py +++ b/modules/datastructures/TrainData_crilin.py @@ -12,7 +12,7 @@ import gzip import pandas as pd -from datastructures import TrainData_NanoML +from datastructures.TrainData_NanoML import TrainData_NanoML class TrainData_crilin(TrainData_NanoML): diff --git a/modules/datastructures/TrainData_fcc.py b/modules/datastructures/TrainData_fcc.py new file mode 100644 index 00000000..9cf0fdb0 --- /dev/null +++ b/modules/datastructures/TrainData_fcc.py @@ -0,0 +1,173 @@ + + + +from DeepJetCore.TrainData import TrainData, fileTimeOut +from DeepJetCore import SimpleArray +import numpy as np +import uproot3 as uproot +import awkward as ak1 +from numba import jit +import gzip +import os +import pickle + +#@jit(nopython=False) +def truth_loop(link_list :list, + t_dict:dict, + part_p_list :list, + ): + + nevts = len(link_list) + for ie in range(nevts):#event + nhits = len(link_list[ie]) + for ih in range(nhits): + idx = -1 + mom = 0. 
+ if link_list[ie][ih] >= 0: + idx = link_list[ie][ih] + mom = part_p_list[ie][idx] + + t_dict['t_idx'].append([idx]) + t_dict['t_energy'].append([mom]) + + t_dict['t_pos'].append([0.,0.,0.]) + t_dict['t_time'].append([0.]) + t_dict['t_pid'].append([0.,0.,0.,0.,0.,0.]) + t_dict['t_spectator'].append([0.]) + t_dict['t_fully_contained'].append([1.]) + t_dict['t_rec_energy'].append([mom]) # THIS WILL NEED TO BE ADJUSTED + t_dict['t_is_unique'].append([1]) #does not matter really + + + return t_dict + + +class TrainData_fcc(TrainData): + + def branchToFlatArray(self, b, return_row_splits=False, dtype='float32'): + + a = b.array() + nevents = a.shape[0] + rowsplits = [0] + + for i in range(nevents): + rowsplits.append(rowsplits[-1] + a[i].shape[0]) + + rowsplits = np.array(rowsplits, dtype='int64') + + if return_row_splits: + return np.expand_dims(np.array(a.flatten(),dtype=dtype), axis=1),np.array(rowsplits, dtype='int64') + else: + return np.expand_dims(np.array(a.flatten(),dtype=dtype), axis=1) + + def convertFromSourceFile(self, filename, weighterobjects, istraining, treename="events"): + + fileTimeOut(filename, 10)#wait 10 seconds for file in case there are hiccups + tree = uproot.open(filename)[treename] + + ''' + + hit_x, hit_y, hit_z: the spatial coordinates of the voxel centroids that registered the hit + hit_dE: the energy registered in the voxel (signal + BIB noise) + recHit_dE: the 'reconstructed' hit energy, i.e. the energy deposited by signal only + evt_dE: the total energy deposited by the signal photon in the calorimeter + evt_ID: an int label for each event -only for bookkeeping, should not be needed + isSignal: a flag, -1 if only BIB noise, 0 if there is also signal hit deposition + + ''' + + hit_x, rs = self.branchToFlatArray(tree["hit_x"], True) + hit_y = self.branchToFlatArray(tree["hit_y"]) + hit_z = self.branchToFlatArray(tree["hit_z"]) + hit_t = self.branchToFlatArray(tree["hit_t"]) + hit_e = self.branchToFlatArray(tree["hit_e"]) + hit_theta = self.branchToFlatArray(tree["hit_theta"]) + + + zerosf = 0.*hit_e + + print('hit_e',hit_e) + hit_e = np.where(hit_e<0., 0., hit_e) + + + farr = SimpleArray(np.concatenate([ + hit_e, + zerosf, + zerosf, #indicator if it is track or not + zerosf, + hit_theta, + hit_x, + hit_y, + hit_z, + zerosf, + hit_t + ], axis=-1), rs,name="recHitFeatures") + + + + # create truth + hit_genlink = tree["hit_genlink0"].array() + part_p = tree["part_p"].array() + + t = { + 't_idx' : [], #names are optional + 't_energy' : [], + 't_pos' : [], #three coordinates + 't_time' : [] , + 't_pid' : [] , #6 truth classes + 't_spectator' : [], + 't_fully_contained' : [], + 't_rec_energy' : [], + 't_is_unique' : [] + } + + #do this with numba + t = truth_loop(hit_genlink.tolist(), + t, + part_p.tolist(), + ) + + for k in t.keys(): + if k == 't_idx' or k == 't_is_unique': + t[k] = np.array(t[k], dtype='int32') + else: + t[k] = np.array(t[k], dtype='float32') + t[k] = SimpleArray(t[k], rs,name=k) + + return [farr, + t['t_idx'], t['t_energy'], t['t_pos'], t['t_time'], + t['t_pid'], t['t_spectator'], t['t_fully_contained'], + t['t_rec_energy'], t['t_is_unique'] ],[], [] + + + + def writeOutPrediction(self, predicted, features, truth, weights, outfilename, inputfile): + outfilename = os.path.splitext(outfilename)[0] + '.bin.gz' + # print("hello", outfilename, inputfile) + + outdict = dict() + outdict['predicted'] = predicted + outdict['features'] = features + outdict['truth'] = truth + + print("Writing to ", outfilename) + with gzip.open(outfilename, "wb") as 
mypicklefile: + pickle.dump(outdict, mypicklefile) + print("Done") + + def writeOutPredictionDict(self, dumping_data, outfilename): + ''' + this function should not be necessary... why break with DJC standards? + ''' + if not str(outfilename).endswith('.bin.gz'): + outfilename = os.path.splitext(outfilename)[0] + '.bin.gz' + + with gzip.open(outfilename, 'wb') as f2: + pickle.dump(dumping_data, f2) + + def readPredicted(self, predfile): + with gzip.open(predfile) as mypicklefile: + return pickle.load(mypicklefile) + + + diff --git a/modules/model_blocks.py b/modules/model_blocks.py index 137fd800..880a61ba 100644 --- a/modules/model_blocks.py +++ b/modules/model_blocks.py @@ -6,7 +6,7 @@ import tensorflow as tf from Initializers import EyeInitializer from GravNetLayersRagged import CondensateToIdxs, EdgeCreator -from Layers import SplitFeatures +from Layers import SplitFeatures, FlatNeighbourFeatures, Sqrt from datastructures.TrainData_NanoML import n_id_classes @@ -36,8 +36,6 @@ def create_outputs(x, n_ccoords=3, pred_beta = Dense(1, activation='sigmoid',name = name_prefix+'_beta')(x) pred_ccoords = Dense(n_ccoords, - #this initialisation is much better than standard glorot - kernel_initializer=EyeInitializer(stddev=0.001), use_bias=False, name = name_prefix+'_clustercoords' )(x) #bias has no effect @@ -1577,7 +1575,7 @@ def intermediate_condensation( from GraphCondensationLayers import MLGraphCondensationMetrics, LLGraphCondensationScore, LLGraphCondensationEdges from DebugLayers import PlotGraphCondensation, PlotGraphCondensationEfficiency from LossLayers import LLValuePenalty -from Layers import CheckNaN +from Layers import CheckNaN, SphereActivation def pre_graph_condensation( orig_inputs, @@ -1610,8 +1608,14 @@ def pre_graph_condensation( energy = orig_inputs['rechit_energy'] is_track = orig_inputs['is_track'] - x = ConditionalScaledGooeyBatchNorm( - name=name+'_cond_batchnorm', + + x = ScaledGooeyBatchNorm2( + name=name+'_cond_batchnorm_a', + record_metrics = record_metrics)([x, is_track]) + + x = ScaledGooeyBatchNorm2( + name=name+'_cond_batchnorm_b', + invert_condition = True, record_metrics = record_metrics)([x, is_track]) x_in = x @@ -1819,6 +1823,7 @@ def mini_pre_graph_condensation( low_energy_cut = 2., publish=None, dynamic_spectators=True, + coords = None, first_call=True): activation = 'elu' @@ -1834,7 +1839,8 @@ def mini_pre_graph_condensation( orig_inputs[k] = CheckNaN(name=name+'_pre_check_'+k)(orig_inputs[k]) x = orig_inputs['features'] # coords - coords = orig_inputs['prime_coords'] + if coords is None: + coords = orig_inputs['prime_coords'] rs = orig_inputs['row_splits'] energy = orig_inputs['rechit_energy'] is_track = orig_inputs['is_track'] @@ -2139,6 +2145,210 @@ def intermediate_graph_condensation( return trans_a, out_truth, sum_energy +def tiny_pc_pool( + orig_inputs, + debug_outdir='', + trainable=False, + name='pre_graph_pool', + debugplots_after=-1, + record_metrics=True, + reduction_target = 0.05, + K_loss = 48, + low_energy_cut_target = 1.0, + first_embed = True, + coords = None, + publish=None, + dmp_steps=[8,8,8,8], + dmp_compress = 32, + K_nn = 16, + K_gp = 5, + is_second=False, + new_format=True): + ''' + This function needs pre-processed input (from condition_input) + ''' + + edge_dense = [16] + edge_pre_nodes = 32 + dwmp_activation = 'elu' + + if is_second: + dmp_steps = [32,32,32,32] + dmp_compress = 64 + low_energy_cut_target = 0.5 + K_nn = 32 + K_gp = 8 + first_embed = False + edge_dense = [32,24,16] + edge_pre_nodes = 64 + dwmp_activation = 'tanh' 
#more smooth + + + ## gather inputs and norm + + x = orig_inputs['features'] # coords + if coords is None: + coords = orig_inputs['prime_coords'] + coords = ScaledGooeyBatchNorm2( + name = name+'_coords_batchnorm', trainable=trainable, + fluidity_decay=1e-2,#can freeze quickly + )(coords) + + coords = PlotCoordinates(plot_every=debugplots_after, + outdir=debug_outdir,name=name+'_coords', + publish=publish)( + [coords, + orig_inputs['rechit_energy'], + orig_inputs['t_idx'],orig_inputs['row_splits']]) + + rs = orig_inputs['row_splits'] + energy = orig_inputs['rechit_energy'] + is_track = orig_inputs['is_track'] + + + x = ScaledGooeyBatchNorm2( + name=name+'_cond_batchnorm_a', trainable=trainable)([x, is_track]) + + x = ScaledGooeyBatchNorm2( + name=name+'_cond_batchnorm_b', trainable=trainable, + invert_condition = True)([x, is_track]) + + x_in = x + x_emb = x + #create different embeddings for tracks and hits + if first_embed: + + x_track = Dense(16, activation='elu', name=name+'emb_xtrack',trainable=trainable)(x) + x_hit = Dense(16, activation='elu', name=name+'emb_xhit',trainable=trainable)(x) + x = MixWhere()([is_track, x_track, x_hit]) + x = ScaledGooeyBatchNorm2(name = name+'_emb_batchnorm', trainable=trainable, + )(x) + x_emb = x + + #simple gravnet + nidx,dist = KNN(K=K_nn,record_metrics=record_metrics,name=name+'_np_knn', + min_bins=20)([coords, #stop gradient here as it's given explicitly below + orig_inputs['row_splits']])#hard code it here, this is optimised given our datasets + + dist,nidx = SortAndSelectNeighbours(-1)([dist,nidx])#make it easy for edges + dist_orig = dist + dist = Sqrt()(dist)#make a stronger distance gradient for message passing + + + #each is a simple attention head + for i,n in enumerate(dmp_steps): + #this is a super lightweight 'guess' attention; coords get explicit gradient down there + + x = DistanceWeightedMessagePassing([n],name=name+'np_dmp'+str(i), + activation=dwmp_activation,#keep output in check + exp_distances = True, + trainable=trainable)([x,nidx,dist])# hops are rather light + + x = Dense(dmp_compress, activation='elu', name=name+'_pcp_x_out'+str(i),trainable=trainable)(x) + + x = ScaledGooeyBatchNorm2( + name = name+'_dmp_batchnorm'+str(i), + trainable=trainable, + )(x) + x = Concatenate()([x, x_emb])#skip connect + + + score = Dense(1, activation='sigmoid', + name=name+'_pcp_score', + trainable=trainable)(x) + + score = LLGraphCondensationScore( + name=name+'_ll_graph_condensation_score', + record_metrics = record_metrics, + K=K_loss, + penalty_fraction = 1. - reduction_target, + active=trainable, + low_energy_cut = low_energy_cut_target #allow everything below 2 GeV to be removed (other than tracks) + )([score, coords, orig_inputs['t_idx'], orig_inputs['t_energy'], rs]) + + trans_a = CreateGraphCondensation( + trainable = trainable, + reduction_target = reduction_target, + K=K_gp, + name=name+'_pcp_create', + )(score,coords,rs, #nidx,dist_orig, #do not use nidx here yet. 
+ always_promote = is_track) + + + + x_e = Concatenate()([x,x_in])#skip + + x_e = CreateGraphCondensationEdges( + edge_dense=edge_dense, + pre_nodes=edge_pre_nodes, + K=K_gp, + trainable=trainable, + name=name+'_gc_edges')(x_e, trans_a) + + if new_format: + dist = StopGradient()(dist) + #add non orderinvariant info + x_e = Concatenate()([dist,x_e,x]) + x_e = Dense(32,name=name+'_graphn_edges_d1', + trainable=trainable, activation='elu')(x_e) + x_e = Dense(K_gp+1, name=name+'_graphn_edges_d2', + trainable=trainable, activation='softmax')(x_e)#one more for noise + + x_e = LLGraphCondensationEdges( + name=name+'_ll_graph_condensation_edges', + active=trainable, + print_loss = trainable, + record_metrics=record_metrics + )(x_e, trans_a, orig_inputs['t_idx']) + + + trans_a = InsertEdgesIntoTransition()(x_e, trans_a) + + + trans_a = MLGraphCondensationMetrics( + name = name + '_graphcondensation_metrics', + record_metrics = record_metrics, + )(trans_a, orig_inputs['t_idx'], orig_inputs['t_energy'],is_track=is_track) + + orig_inputs['t_energy'] = PlotGraphCondensationEfficiency( + plot_every = debugplots_after, + outdir= debug_outdir , + name = name + '_efficiency', + publish = publish)(orig_inputs['t_energy'], orig_inputs['t_idx'], trans_a) + + out={} + + for k in orig_inputs.keys(): + if 't_' == k[0:2]: + out[k] = SelectUp()(orig_inputs[k],trans_a) + + + out['prime_coords'] = PushUp(add_self=True)(orig_inputs['prime_coords'], + trans_a, weight = energy) + + out['rechit_energy'] = PushUp(mode='sum', add_self=True)(energy, trans_a) + out['is_track'] = SelectUp()(is_track, trans_a) + out['row_splits'] = trans_a['rs_up'] + + out['coords'] = PushUp(add_self=True)(orig_inputs['coords'], trans_a, weight = energy) + out['cond_coords'] = PushUp(add_self=True)(coords, trans_a, weight = energy) + + x_o = Concatenate()([x,x_in]) + out['down_features'] = x_o # this is for further "bouncing" + out['up_features'] = SelectUp()(x_o, trans_a) + + out['select_prime_coords'] = SelectUp()(orig_inputs['prime_coords'], trans_a) + + x_of = PushUp()(x_o, trans_a) + x_of2 = PushUp()(x_o, trans_a, weight = energy) + out['features'] = Concatenate()([out['prime_coords'],out['up_features'],x_of, x_of2]) + + out['cond_coords_down'] = coords #mostly for reference + + + + return trans_a, out + diff --git a/modules/model_tools.py b/modules/model_tools.py index f3a67ce0..47df0370 100644 --- a/modules/model_tools.py +++ b/modules/model_tools.py @@ -3,9 +3,9 @@ from Layers import RobustModel,ExtendedMetricsModel from DeepJetCore.customObjects import get_custom_objects from DeepJetCore.modeltools import apply_weights_where_possible -from DeepJetCore import DataCollection, TrainData import numpy as np + def apply_weights_from_path(path_to_weight_model, existing_model, return_weight_model=False, apply_optimizer=False): if isinstance(existing_model, (RobustModel,ExtendedMetricsModel)): diff --git a/modules/object_condensation.py b/modules/object_condensation.py index ab4d1f3b..4345fad1 100644 --- a/modules/object_condensation.py +++ b/modules/object_condensation.py @@ -160,11 +160,9 @@ def att_func(self,dsq_k_m): def V_att_k(self): ''' ''' - x_k_e = tf.expand_dims(self.x_k,axis=1) - N_k = tf.reduce_sum(self.mask_k_m, axis=1) - dsq_k_m = tf.reduce_sum((self.x_k_m - x_k_e)**2, axis=-1, keepdims=True) #K x V-obj x 1 + dsq_k_m = self.calc_dsq_att() #K x V-obj x 1 sigma = self.weighted_d_k_m(dsq_k_m) #create gradients for all @@ -184,6 +182,16 @@ def rep_func(self,dsq_k_v): def weighted_d_k_m(self, dsq): # dsq K x V x 1 return 
tf.expand_dims(self.d_k, axis=1) # K x 1 x 1 + + def calc_dsq_att(self): + x_k_e = tf.expand_dims(self.x_k,axis=1) + dsq_k_m = tf.reduce_sum((self.x_k_m - x_k_e)**2, axis=-1, keepdims=True) #K x V-obj x 1 + return dsq_k_m + + def calc_dsq_rep(self): + dsq = tf.expand_dims(self.x_k, axis=1) - tf.expand_dims(self.x_v, axis=0) #K x V x C + dsq = tf.reduce_sum(dsq**2, axis=-1, keepdims=True) #K x V x 1 + return dsq def V_rep_k(self): @@ -191,8 +199,7 @@ def V_rep_k(self): N_k = tf.reduce_sum(self.Mnot, axis=1) #future remark: if this gets too large, one could use a kNN here - dsq = tf.expand_dims(self.x_k, axis=1) - tf.expand_dims(self.x_v, axis=0) #K x V x C - dsq = tf.reduce_sum(dsq**2, axis=-1, keepdims=True) #K x V x 1 + dsq = self.calc_dsq_rep() # nogradbeta = tf.stop_gradient(self.beta_k_m) #weight. tf.reduce_sum( tf.exp(-dsq) * d_v_e, , axis=1) / tf.reduce_sum( tf.exp(-dsq) ) @@ -290,7 +297,27 @@ def add_to_terms(self, return V_att, V_rep, Noise_pen, B_pen, pll, high_B_pen +class Hinge_OC_per_sample(Basic_OC_per_sample): + ''' + This is the classic repulsive hinge loss + ''' + def __init__(self, **kwargs): + super(Hinge_OC_per_sample, self).__init__(**kwargs) + + def rep_func(self,dsq_k_v): + return tf.nn.relu(1. - tf.sqrt(dsq_k_v + 1e-6)) + +class Hinge_Manhatten_OC_per_sample(Hinge_OC_per_sample): + def calc_dsq_att(self): + x_k_e = tf.expand_dims(self.x_k,axis=1) + dsq_k_m = tf.reduce_sum(tf.abs(self.x_k_m - x_k_e), axis=-1, keepdims=True) #K x V-obj x 1 + return dsq_k_m**2 #still square it since that's what the function should return + + def calc_dsq_rep(self): + dsq = tf.expand_dims(self.x_k, axis=1) - tf.expand_dims(self.x_v, axis=0) #K x V x C + dsq = tf.reduce_sum(tf.abs(dsq), axis=-1, keepdims=True) #K x V x 1 + return dsq**2 #still square it since that's what the function should return class PushPull_OC_per_sample(Basic_OC_per_sample): diff --git a/modules/oc_helper_ops.py b/modules/oc_helper_ops.py index fa3b776a..d9685614 100644 --- a/modules/oc_helper_ops.py +++ b/modules/oc_helper_ops.py @@ -74,13 +74,15 @@ def _CreateMidxGrad(op, sel_dxs, m_not): return None -def SelectWithDefault(indices, tensor, default=0): +def SelectWithDefault(indices, tensor, default=0, no_check=False): expidxs = tf.expand_dims(indices,axis=2) tfidxs = tf.where(expidxs<0,0,expidxs) gtens = tf.gather_nd(tensor,tfidxs) out = tf.where(expidxs<0, default, gtens) + if no_check: + return out #check if the size ends up as we might want with tf.control_dependencies([ tf.assert_equal(tf.shape(tf.shape(out)), tf.shape(tf.shape(indices)) + 1), diff --git a/modules/push_knn_op.py b/modules/push_knn_op.py index edc242df..9a50d546 100755 --- a/modules/push_knn_op.py +++ b/modules/push_knn_op.py @@ -10,8 +10,10 @@ def PushKnn(w,f,nidx): ''' Pushes features (summing them) to the neighbours. - This assumes that if there is a self neighbour index, that is actually means something! + This assumes that if there is a self neighbour index, that actually means something! 
This op is compatible with '-1' indices as synonym for no neighbour + + This op does not assume that an index of '-1' truncates the list ''' #go through the columns diff --git a/scripts/inspect_batchnorm.py b/scripts/inspect_batchnorm.py new file mode 100755 index 00000000..e74784e3 --- /dev/null +++ b/scripts/inspect_batchnorm.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +from argparse import ArgumentParser +parser = ArgumentParser('Inspect all batchnorm layers of a model') +parser.add_argument('inputModel') +args = parser.parse_args() + +from DeepJetCore.modeltools import load_model +from Layers import ScaledGooeyBatchNorm2 + +m = load_model(args.inputModel) + +for l in m.layers: + if isinstance(l, ScaledGooeyBatchNorm2): + ws = l.weights + print('\n\nLayer',l.name,':\n') + for w in ws: + print(w.name,w.shape,'\n',w.numpy()) + diff --git a/setup.sh b/setup.sh index b2a97b28..89521c91 100755 --- a/setup.sh +++ b/setup.sh @@ -2,6 +2,6 @@ cd $HGCALML/modules cd compiled -make -j +make -j4 cd $HGCALML git submodule update --init --recursive
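
A note on the loss `implementation` switch introduced above (this is an explanatory sketch, not part of the patch): `modules/object_condensation.py` now provides `Hinge_OC_per_sample` and `Hinge_Manhatten_OC_per_sample`, selectable in `LLFullObjectCondensation` via `implementation='hinge'` or `'hinge_manhatten'` (spelling as in the code). The snippet below only illustrates the repulsive term these classes implement; the standalone function names are invented here for clarity, and the surrounding per-sample bookkeeping (`V_att_k`, `V_rep_k`, ...) is omitted.

```python
import tensorflow as tf

# Illustration only: the repulsive hinge potential corresponding to
# Hinge_OC_per_sample.rep_func. dsq holds squared distances (K x V x 1)
# between condensation points and all vertices in clustering space;
# the small epsilon keeps the sqrt gradient finite at zero distance.
def hinge_repulsion(dsq):
    return tf.nn.relu(1. - tf.sqrt(dsq + 1e-6))

# The Manhattan variant only changes how the (squared) distances are built,
# mirroring Hinge_Manhatten_OC_per_sample.calc_dsq_rep; the potential above
# is reused unchanged.
def manhattan_dsq_rep(x_k, x_v):
    d = tf.expand_dims(x_k, axis=1) - tf.expand_dims(x_v, axis=0)  # K x V x C
    d = tf.reduce_sum(tf.abs(d), axis=-1, keepdims=True)           # K x V x 1
    return d**2  # squared, since rep_func expects a squared distance
```

Since `implementation` is now also stored in `get_config`, a model saved with one of the hinge variants restores with the same repulsive term.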