Pointwise features generation #2

Open · wants to merge 18 commits into base: main
add_agent.sh (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ gpu_name=${6:-"gypsum-1080ti"} # "gypsum-1080ti"
for ((i = 1; i <= ${n_agents}; i++)); do
JOB_DESC=${model}_${dataset}_sweep${seed}-${i} && JOB_NAME=${JOB_DESC}_$(date +%s) && \
sbatch -J ${JOB_NAME} -e jobs/${JOB_NAME}.err -o jobs/${JOB_NAME}.log \
- --partition=${gpu_name} --gres=gpu:1 --mem=80G --time=12:00:00 \
+ --partition=${gpu_name} --gres=gpu:1 --mem=120G --time=12:00:00 \
run_sbatch.sh e2e_scripts/train.py \
--dataset="${dataset}" \
--dataset_random_seed=${seed} \
e2e_debug/solve.py (147 changes: 147 additions & 0 deletions)
@@ -0,0 +1,147 @@
import json
import argparse
import cvxpy as cp
import logging
import numpy as np
import torch

from IPython import embed

from e2e_pipeline.hac_cut_layer import HACCutLayer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)


class Parser(argparse.ArgumentParser):
def __init__(self):
super().__init__()
self.add_argument(
"--data_fpath", type=str
)
self.add_argument(
"--data_idx", type=int, default=0
)
self.add_argument(
"--scs_max_sdp_iters", type=int, default=50000
)
self.add_argument(
"--scs_silent", action="store_true",
)
self.add_argument(
"--scs_eps", type=float, default=1e-3
)
self.add_argument(
"--scs_scale", type=float, default=1e-1,
)
self.add_argument(
"--scs_dont_normalize", action="store_true",
)
self.add_argument(
"--scs_use_indirect", action="store_true",
)
self.add_argument(
"--scs_dont_use_quad_obj", action="store_true",
)
self.add_argument(
"--scs_alpha", type=float, default=1.5
)
self.add_argument(
"--scs_log_csv_filename", type=str,
)
self.add_argument(
"--max_scaling", action="store_true",
)
self.add_argument(
"--interactive", action="store_true",
)


if __name__ == '__main__':
parser = Parser()
args = parser.parse_args()
logger.info("Script arguments:")
logger.info(args.__dict__)

# Read error file
logger.info("Reading input data")
if args.data_fpath.endswith('.pt'):
_W_val = torch.load(args.data_fpath, map_location='cpu').numpy()
else:
with open(args.data_fpath, 'r') as fh:
data = json.load(fh)
assert len(data['errors']) > 0
# Pick specific error instance to process
error_data = data['errors'][args.data_idx]

# Extract input data from the error instance
_raw = np.array(error_data['model_call_args']['data'])
_W_val = np.array(error_data['cvxpy_layer_args']['W_val'])

# Construct cvxpy problem
logger.info('Constructing optimization problem')
# edge_weights = _W_val.tocoo()
n = _W_val.shape[0]
W = _W_val
# W = csr_matrix((edge_weights.data, (edge_weights.row, edge_weights.col)), shape=(n, n))
X = cp.Variable((n, n), PSD=True)
# Build constraint set
constraints = [
cp.diag(X) == np.ones((n,)),
X[:n, :] >= 0,
# X[:n, :] <= 1
]
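# Note: together with the cp.Maximize(cp.trace(W_scaled @ X)) objective created in the loop
# below, this is the standard correlation-clustering SDP relaxation:
#   maximize <W, X>  subject to  diag(X) = 1,  X >= 0 (elementwise),  X PSD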

# Setup HAC Cut
hac_cut = HACCutLayer()
hac_cut.eval()

sdp_obj_value = float('inf')
result_idxs, results_X, results_clustering = [], [], []
no_solution_scaling_factors = []
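# Sweep candidate scalings of W: i == 0 normalizes by max|W|, while i >= 1 divides W by the
# integer i (so i == 1 corresponds to no scaling).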
for i in range(0, 10): # n
# Skipping 1; no scaling leads to non-convergence (infinite objective value)
if i == 0:
scaling_factor = np.max(np.abs(W))
else:
scaling_factor = i
if args.max_scaling:
continue
logger.info(f'Scaling factor={scaling_factor}')
# Create problem
W_scaled = W / scaling_factor
problem = cp.Problem(cp.Maximize(cp.trace(W_scaled @ X)), constraints)
# Solve problem
sdp_obj_value = problem.solve(
solver=cp.SCS,
verbose=not args.scs_silent,
max_iters=args.scs_max_sdp_iters,
eps=args.scs_eps,
normalize=not args.scs_dont_normalize,
alpha=args.scs_alpha,
scale=args.scs_scale,
use_indirect=args.scs_use_indirect,
# use_quad_obj=not args.scs_dont_use_quad_obj
)
logger.info(f"@scaling={scaling_factor}, objective value = {sdp_obj_value}, norm={np.linalg.norm(W_scaled)}")
if sdp_obj_value != float('inf'):
result_idxs.append(i)
results_X.append(X.value)
# Find clustering solution
hac_cut.get_rounded_solution(torch.tensor(X.value), torch.tensor(W_scaled))
results_clustering.append(hac_cut.cluster_labels.numpy())
else:
no_solution_scaling_factors.append(scaling_factor)
logger.info(f"Solution not found = {len(no_solution_scaling_factors)}")
logger.info(f"Solution found = {len(results_X)}")

# logger.info("Same clustering:")
# for i in range(len(results_clustering)-1):
# logger.info(np.array_equal(results_clustering[i], results_clustering[i + 1]))
# logger.info(f"Solution found with scaling factor = {scaling_factor}")
# if args.interactive and sdp_obj_value == float('inf'):
# embed()

if args.interactive:
embed()
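
For reference, solve.py accepts either a saved torch tensor (.pt) of edge weights or a JSON error dump; the sketch below shows the minimal JSON layout inferred from the fields the script reads (the file name, matrix values, and feature payload are purely illustrative).

import json
import numpy as np

# Hypothetical 3x3 edge-weight matrix; real dumps come from failed solver calls.
W_val = [[0.0, 0.7, -0.2],
         [0.0, 0.0, 0.4],
         [0.0, 0.0, 0.0]]
dump = {
    "errors": [
        {
            "model_call_args": {"data": np.zeros((3, 5)).tolist()},  # raw features (not used by the SDP itself)
            "cvxpy_layer_args": {"W_val": W_val},
        }
    ]
}
with open("error_dump.json", "w") as fh:
    json.dump(dump, fh)
# The script could then be pointed at this file, e.g.:
#   python e2e_debug/solve.py --data_fpath error_dump.json --data_idx 0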
e2e_pipeline/cc_inference.py (7 changes: 4 additions & 3 deletions)
@@ -17,19 +17,20 @@ class CCInference(torch.nn.Module):
Correlation clustering inference-only model. Expects edge weights and the number of nodes as input.
"""

- def __init__(self, sdp_max_iters, sdp_eps):
+ def __init__(self, sdp_max_iters, sdp_eps, sdp_scale, use_sdp):
super().__init__()
self.uncompress_layer = UncompressTransformLayer()
- self.sdp_layer = SDPLayer(max_iters=sdp_max_iters, eps=sdp_eps)
+ self.sdp_layer = SDPLayer(max_iters=sdp_max_iters, eps=sdp_eps, scale_input=sdp_scale)
self.hac_cut_layer = HACCutLayer()
+ self.use_sdp = use_sdp

def forward(self, edge_weights, N, min_id=0, threshold=None, verbose=False):
edge_weights = torch.squeeze(edge_weights)
if threshold is not None:
# threshold is used to convert a similarity score (in [0,1]) into edge weights (in R, i.e. + and -)
edge_weights = torch.sigmoid(edge_weights) - threshold
edge_weights_uncompressed = self.uncompress_layer(edge_weights, N)
- output_probs = self.sdp_layer(edge_weights_uncompressed, N)
+ output_probs = self.sdp_layer(edge_weights_uncompressed, N, use_sdp=self.use_sdp)
pred_clustering = self.hac_cut_layer(output_probs, edge_weights_uncompressed)

if verbose:
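
To illustrate the threshold shift in CCInference.forward above (toy scores, not model output): the sigmoid maps raw pairwise scores into [0, 1], and subtracting the threshold yields signed edge weights.

import torch

raw_scores = torch.tensor([2.0, -1.0, 0.1])        # hypothetical pairwise scores
threshold = 0.5
edge_weights = torch.sigmoid(raw_scores) - threshold
# tensor([ 0.3808, -0.2311,  0.0250]): positive values favor merging, negative favor splitting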
e2e_pipeline/hac_cut_layer.py (22 changes: 15 additions & 7 deletions)
@@ -13,7 +13,7 @@ def __init__(self):
Takes fractional SDP output as input, and simultaneously builds & cuts avg. HAC tree to get rounded solution.
Executes straight-through estimator as the backward pass.
"""
- def get_rounded_solution(self, X, weights, _MAX_DIST=10, use_similarities=True, max_similarity=1, verbose=False):
+ def get_rounded_solution(self, X, weights, _MAX_DIST=1000, use_similarities=True, max_similarity=1, verbose=False):
"""
X is a symmetric NxN matrix of fractional, decision values with a 1-diagonal (output from the SDP layer)
weights is an NxN upper-triangular (shift 1) matrix of edge weights
@@ -34,7 +34,8 @@ def get_rounded_solution(X, weights, _MAX_DIST=10, use_similarities=True,
round_matrix = torch.eye(D, device=device)

# Take the upper triangular and mask the other values with a large number
- Y = _MAX_DIST * torch.ones(D, D, device=device).tril() + (max_similarity-X if use_similarities else X).triu(1)
+ _MAX_DIST = torch.max(torch.abs(X)) * _MAX_DIST
+ Y = _MAX_DIST * torch.ones(D, D, device=device).tril() + (max_similarity - X if use_similarities else X).triu(1)
# Compute the dissimilarity minima per row
values, indices = torch.min(Y, dim=1)

@@ -100,7 +101,7 @@ def get_rounded_solution(X, weights, _MAX_DIST=10, use_similarities=True,
# Energy calculations
clustering[max_node] = clustering[parent_1] + clustering[parent_2]
leaf_indices = torch.where(clustering[max_node])[0]
- leaf_edges = torch.meshgrid(leaf_indices, leaf_indices)
+ leaf_edges = torch.meshgrid(leaf_indices, leaf_indices, indexing='ij')
energy[max_node] = energy[parent_1] + energy[parent_2]
merge_energy = torch.sum(weights[leaf_edges])
if merge_energy >= energy[max_node]:
@@ -123,9 +124,16 @@ def get_rounded_solution(X, weights, _MAX_DIST=10, use_similarities=True,
self.round_matrix = round_matrix
self.cluster_labels = clustering[-1]
self.parents = parents
- objective_matrix = weights * torch.triu(round_matrix, diagonal=1)
- self.objective_value = (energy[max_node] - torch.sum(objective_matrix[objective_matrix < 0])).item() # MA
+ with torch.no_grad():
+     objective_matrix = weights * torch.triu(round_matrix, diagonal=1)
+     self.objective_value = (energy[max_node] - torch.sum(objective_matrix[objective_matrix < 0])).item() # MA
return self.round_matrix

- def forward(self, X, W, use_similarities=True):
-     return X + (self.get_rounded_solution(X, W, use_similarities=use_similarities) - X).detach()
+ def forward(self, X, W, use_similarities=True, return_triu=False):
+     solution = X + (self.get_rounded_solution(X, W,
+                                               use_similarities=use_similarities,
+                                               max_similarity=torch.max(X)) - X).detach()
+     if return_triu:
+         triu_indices = torch.triu_indices(len(solution), len(solution), offset=1)
+         return solution[triu_indices[0], triu_indices[1]]
+     return solution
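
A small sketch of what the new return_triu=True option produces, on a made-up 3x3 rounded solution:

import torch

solution = torch.tensor([[1., 1., 0.],
                         [1., 1., 0.],
                         [0., 0., 1.]])             # toy rounded clustering matrix
triu_indices = torch.triu_indices(len(solution), len(solution), offset=1)
flat = solution[triu_indices[0], triu_indices[1]]   # tensor([1., 0., 0.])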
e2e_pipeline/hac_inference.py (6 changes: 3 additions & 3 deletions)
@@ -33,7 +33,6 @@ def tune_threshold(self, model, dataloader, device, n_trials=1000):
all_gold = []
blockwise_trees = []
all_dists = []
- max_pred_id = -1 # In each iteration, add to all blockwise predicted IDs to distinguish from previous blocks
n_features = dataloader.dataset[0][0].shape[1]
for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Tuning threshold on dev')):
data, _, cluster_ids = batch
@@ -46,7 +45,8 @@ def tune_threshold(self, model, dataloader, device, n_trials=1000):

# Forward pass through the e2e model
data = data.to(device)
- tree_and_alts, dists = self.cluster(model(data), block_size, return_tree=True)
+ edge_weights = model(data, N=len(cluster_ids), warmstart=True)
+ tree_and_alts, dists = self.cluster(edge_weights, block_size, return_tree=True)
blockwise_trees.append(tree_and_alts)
all_dists.append(dists)

@@ -61,7 +61,7 @@ def tune_threshold(self, model, dataloader, device, n_trials=1000):
best_dev_metric = -1
for _thresh in tqdm(thresholds, desc="Finding best cut threshold"):
all_pred = []
- max_pred_id = -1
+ max_pred_id = -1 # In each iter, add to all blockwise predicted IDs to distinguish from previous blocks
for (_hac, _hac_alts) in blockwise_trees:
_cut_labels = self.cut_tree(_hac, _hac_alts, _thresh)
pred_cluster_ids = _cut_labels + (max_pred_id + 1)
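
For clarity, a toy illustration of the max_pred_id bookkeeping above, which keeps per-block cluster IDs disjoint when blocks are concatenated (labels are invented):

import numpy as np

max_pred_id = -1
all_pred = []
for _cut_labels in [np.array([0, 0, 1]), np.array([0, 1, 1, 2])]:  # per-block cut labels
    pred_cluster_ids = _cut_labels + (max_pred_id + 1)              # shift past previous blocks
    max_pred_id = max(pred_cluster_ids)
    all_pred.extend(pred_cluster_ids.tolist())
# all_pred == [0, 0, 1, 2, 3, 3, 4]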
e2e_pipeline/model.py (13 changes: 9 additions & 4 deletions)
@@ -15,18 +15,21 @@
class EntResModel(torch.nn.Module):
def __init__(self, n_features, neumiss_depth, dropout_p, dropout_only_once, add_neumiss,
neumiss_deq, hidden_dim, n_hidden_layers, add_batchnorm, activation,
- negative_slope, hidden_config, sdp_max_iters, sdp_eps, use_rounded_loss=True):
+ negative_slope, hidden_config, sdp_max_iters, sdp_eps, sdp_scale=False, use_rounded_loss=True,
+ return_triu_on_train=False, use_sdp=True):
super().__init__()
# Layers
self.mlp_layer = MLPLayer(n_features=n_features, neumiss_depth=neumiss_depth, dropout_p=dropout_p,
dropout_only_once=dropout_only_once, add_neumiss=add_neumiss, neumiss_deq=neumiss_deq,
hidden_dim=hidden_dim, n_hidden_layers=n_hidden_layers, add_batchnorm=add_batchnorm,
activation=activation, negative_slope=negative_slope, hidden_config=hidden_config)
self.uncompress_layer = UncompressTransformLayer()
- self.sdp_layer = SDPLayer(max_iters=sdp_max_iters, eps=sdp_eps)
+ self.sdp_layer = SDPLayer(max_iters=sdp_max_iters, eps=sdp_eps, scale_input=sdp_scale)
self.hac_cut_layer = HACCutLayer()
# Configs
self.use_rounded_loss = use_rounded_loss
+ self.return_triu_on_train = return_triu_on_train
+ self.use_sdp = use_sdp

def forward(self, x, N, warmstart=False, verbose=False):
edge_weights = torch.squeeze(self.mlp_layer(x))
@@ -41,14 +44,16 @@ def forward(self, x, N, warmstart=False, verbose=False):
logger.info(f"Size of W_matrix = {edge_weights_uncompressed.size()}")
logger.info(f"\n{edge_weights_uncompressed}")

- output_probs = self.sdp_layer(edge_weights_uncompressed, N)
+ output_probs = self.sdp_layer(edge_weights_uncompressed, N, use_sdp=self.use_sdp, return_triu=(
+     self.training and not self.use_rounded_loss and self.return_triu_on_train))
if verbose:
logger.info(f"Size of X = {output_probs.size()}")
logger.info(f"\n{output_probs}")
if self.training and not self.use_rounded_loss:
return output_probs

- pred_clustering = self.hac_cut_layer(output_probs, edge_weights_uncompressed)
+ pred_clustering = self.hac_cut_layer(output_probs, edge_weights_uncompressed,
+                                      return_triu=(self.training and self.return_triu_on_train))
if verbose:
logger.info(f"Size of X_r = {pred_clustering.size()}")
logger.info(f"\n{pred_clustering}")
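
For context, a hypothetical construction of EntResModel with the new arguments; every hyperparameter value below is a placeholder rather than a recommended setting.

model = EntResModel(
    n_features=39, neumiss_depth=20, dropout_p=0.1, dropout_only_once=True,
    add_neumiss=True, neumiss_deq=False, hidden_dim=512, n_hidden_layers=2,
    add_batchnorm=True, activation="leaky_relu", negative_slope=0.01, hidden_config=None,
    sdp_max_iters=50000, sdp_eps=1e-3,
    sdp_scale=True,               # new: forwarded to SDPLayer as scale_input
    use_rounded_loss=True,
    return_triu_on_train=False,   # new: return only the upper triangle during training
    use_sdp=True,                 # new: forwarded to SDPLayer as use_sdp
)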