Skip to content

Commit

Permalink
Aggregation test on CIFAR-10
Browse files Browse the repository at this point in the history
Signed-off-by: Emanuele Ballarin <[email protected]>
  • Loading branch information
emaballarin committed Dec 11, 2024
1 parent ae263b4 commit 2d7acab
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
47 changes: 47 additions & 0 deletions carsoeval_a_agg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/bash -li
#SBATCH --job-name=eval_carso_sc_a_agg
#SBATCH --mail-type=FAIL,END
#SBATCH --partition=DGX
#SBATCH --time=0-12:00:00
#SBATCH --nodes=1 # Nodes
#SBATCH --ntasks-per-node=7 # GPUs per node
#SBATCH --cpus-per-task=4 # Cores per node / GPUs per node
#SBATCH --mem=96G # 4 * Cores per node
#SBATCH --gres=gpu:7 # GPUs per node
################################################################################
#
sleep 3
#
#source $HOME/.bashrc
#
export CODEHOME="$HOME/Downloads/"
export MYPYTHON="$HOME/micromamba/envs/nightorch/bin/python"
#
export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE))
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
#
echo " "
echo "hostname="$(hostname)
echo "WORLD_SIZE="$WORLD_SIZE
echo "OMP_NUM_THREADS="$OMP_NUM_THREADS
echo "MASTER_ADDR="$MASTER_ADDR
echo "MASTER_PORT="$MASTER_PORT
echo " "
#
################################################################################
cd "$CODEHOME/CARSO/src/"
#
echo "-----------------------------------------------------------------------------------------------------------------"
echo " "
echo "START TIME "$(date +'%Y_%m_%d-%H_%M_%S')
echo " "
echo "-----------------------------------------------------------------------------------------------------------------"
srun "$MYPYTHON" -O "$CODEHOME/CARSO/src/eval_a.py" --dist --e2e --batchsize 70 --agg "logit"
echo "-----------------------------------------------------------------------------------------------------------------"
echo " "
echo "STOP TIME "$(date +'%Y_%m_%d-%H_%M_%S')
echo " "
echo "-----------------------------------------------------------------------------------------------------------------"
#
8 changes: 8 additions & 0 deletions src/eval_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def main_parse() -> argparse.Namespace:
metavar="<n_samples>",
help="Number of sampled recosntructions to classify (default: 8)",
)
parser.add_argument(
"--agg",
type=str,
default="peel",
metavar="<aggregation_method>",
help="Aggregation method for model outputs (default: PeeL)",
)
return parser.parse_args()


Expand Down Expand Up @@ -196,6 +203,7 @@ def main_run(args: argparse.Namespace) -> None:
joint_latent_dim=JOINT_LATENT_DIM,
ensemble_size=args.nsamples,
differentiable_infer=args.e2e,
agg_method=args.agg,
)

# noinspection DuplicatedCode
Expand Down

0 comments on commit 2d7acab

Please sign in to comment.