diff --git a/store/neurostore/resources/base.py b/store/neurostore/resources/base.py
index a1557fc4..dc18f441 100644
--- a/store/neurostore/resources/base.py
+++ b/store/neurostore/resources/base.py
@@ -69,6 +69,10 @@ class BaseView(MethodView):
     _view_fields = {}
     # _default_exclude = None
 
+    @classmethod
+    def check_duplicate(cls, data, record):
+        return False
+
     def get_affected_ids(self, ids):
         """
         Get all the ids that are affected by a change to a record.
         """
@@ -223,16 +227,16 @@ def load_nested_records(cls, data, record=None):
     @classmethod
     def update_or_create(cls, data, id=None, user=None, record=None, flush=True):
         """
-        scenerios:
-        1. cloning a study
-            a. clone everything, a study is an object
-        2. cloning a studyset
-            a. studies are linked to a studyset, so create a new studyset with same links
-        3. cloning an annotation
-            a. annotations are linked to studysets, update when studyset updates
-        4. creating an analysis
+        Scenarios:
+        1. Cloning a study
+            a. Clone everything, a study is an object
+        2. Cloning a studyset
+            a. Studies are linked to a studyset, so create a new studyset with same links
+        3. Cloning an annotation
+            a. Annotations are linked to studysets, update when studyset updates
+        4. Creating an analysis
             a. I should have to own all (relevant) parent objects
-        5. creating an annotation
+        5. Creating an annotation
             a. I should not have to own the studyset to create an annotation
         """
@@ -291,6 +295,14 @@ def update_or_create(cls, data, id=None, user=None, record=None, flush=True):
             return record
 
+        data["user_id"] = current_user.external_id
+        if hasattr(record, "id"):
+            data["id"] = record.id
+        # Check whether this payload duplicates an existing record
+        duplicate = cls.check_duplicate(data, record)
+        if duplicate:
+            return duplicate
+
         # Update all non-nested attributes
         for k, v in data.items():
             if k in cls._parent and v is not None:
diff --git a/store/neurostore/resources/data.py b/store/neurostore/resources/data.py
index a3df78f7..888178cd 100644
--- a/store/neurostore/resources/data.py
+++ b/store/neurostore/resources/data.py
@@ -931,6 +931,57 @@ def join_tables(self, q, args):
             )
         return super().join_tables(q, args)
 
+    @classmethod
+    def check_duplicate(cls, data, record):
+        study_id = data.get("study_id")
+
+        if hasattr(record, "id") and record.id and record.id == data.get("id"):
+            # Not a duplicate: the payload targets this same record
+            return False
+
+        if hasattr(record, "study") and record.study:
+            study = record.study
+        else:
+            study = Study.query.filter_by(id=study_id).first()
+
+        if not study:
+            return False
+
+        name = data.get("name")
+        user_id = data.get("user_id")
+        # Default to an empty list so a payload without points cannot crash the comparison
+        coordinates = data.get("points") or []
+
+        for analysis in study.analyses:
+            if (
+                analysis.name == name
+                and analysis.user_id == user_id
+                and cls._compare_coordinates(analysis.points, coordinates)
+            ):
+                return analysis
+
+        return False
+
+    @staticmethod
+    def _compare_coordinates(existing_points, new_points):
+        # Map existing point ids to their coordinates
+        existing_points_dict = {
+            point.id: (point.x, point.y, point.z) for point in existing_points
+        }
+
+        # Build coordinate sets so the comparison ignores order and duplicates
+        existing_points_set = {(point.x, point.y, point.z) for point in existing_points}
+        new_points_set = set()
+
+        for point in new_points:
+            if "x" in point and "y" in point and "z" in point:
+                new_points_set.add((point["x"], point["y"], point["z"]))
+            elif "id" in point and point["id"] in existing_points_dict:
+                new_points_set.add(existing_points_dict[point["id"]])
+            else:
+                # The point carries neither coordinates nor a known id
+                return False
+
+        return existing_points_set == new_points_set
+
 
 @view_maker
 class ConditionsView(ObjectView, ListView):
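For reviewers: the set-based comparison in _compare_coordinates means duplicate detection ignores point order and repeated coordinates, and a payload may mix raw x/y/z triples with references to existing point ids. A minimal standalone sketch of that semantics, using an illustrative Point namedtuple rather than the actual SQLAlchemy model:

    from collections import namedtuple

    Point = namedtuple("Point", ["id", "x", "y", "z"])

    existing = [Point("p1", 0.0, 0.0, 0.0), Point("p2", 10.0, -4.0, 2.0)]

    # Payload points may carry coordinates directly or reference an existing id
    payload = [{"x": 10.0, "y": -4.0, "z": 2.0}, {"id": "p1"}]

    existing_by_id = {p.id: (p.x, p.y, p.z) for p in existing}
    existing_set = {(p.x, p.y, p.z) for p in existing}

    new_set = set()
    for point in payload:
        if "x" in point and "y" in point and "z" in point:
            new_set.add((point["x"], point["y"], point["z"]))
        elif "id" in point and point["id"] in existing_by_id:
            new_set.add(existing_by_id[point["id"]])

    print(existing_set == new_set)  # True: same coordinates, order ignored
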
diff --git a/store/neurostore/tests/api/test_analyses.py b/store/neurostore/tests/api/test_analyses.py
index bfad2c87..ca0dae6e 100644
--- a/store/neurostore/tests/api/test_analyses.py
+++ b/store/neurostore/tests/api/test_analyses.py
@@ -151,3 +151,34 @@ def test_post_analysis_without_order(auth_client, ingest_neurosynth, session):
 
     # Check if the 'order' field is not None
     assert resp.json()["order"] is not None
+
+
+def test_create_duplicate_analysis(auth_client, ingest_neurosynth, session):
+    # Reassign an existing analysis (and its siblings) to the authenticated user
+    analysis_db = Analysis.query.first()
+    analysis = AnalysisSchema().dump(analysis_db)
+    id_ = auth_client.username
+    user = User.query.filter_by(external_id=id_).first()
+    analysis_db.study.user = user
+    for a in analysis_db.study.analyses:
+        a.user = user
+        session.add(a)
+    session.add(analysis_db.study)
+    session.commit()
+
+    # Remove auto-generated fields before resubmitting the payload
+    for k in ["user", "id", "created_at", "updated_at", "entities"]:
+        analysis.pop(k, None)
+
+    # Create the first analysis
+    resp = auth_client.post("/api/analyses/", data=analysis)
+    assert resp.status_code == 200
+
+    # Attempt to create a duplicate analysis
+    resp_duplicate = auth_client.post("/api/analyses/", data=analysis)
+    assert resp_duplicate.status_code == 200
+
+    # The duplicate request should return the original analysis, not a new one
+    original_analysis = resp.json()
+    duplicate_analysis = resp_duplicate.json()
+    assert original_analysis["id"] == duplicate_analysis["id"]
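
Taken together, the hook follows a template-method pattern: BaseView.check_duplicate is a no-op that subclasses override to opt in to deduplication, and update_or_create short-circuits when a match is found. A condensed sketch of that control flow (simplified from the patch, not the full implementation):

    class BaseView:
        @classmethod
        def check_duplicate(cls, data, record):
            # Default: nothing is ever considered a duplicate
            return False

        @classmethod
        def update_or_create(cls, data, record=None):
            duplicate = cls.check_duplicate(data, record)
            if duplicate:
                # Return the existing record instead of creating a new one
                return duplicate
            ...  # fall through to the normal create/update path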