Skip to content

Commit

Permalink
feat: add ml training within Django admin
Browse files Browse the repository at this point in the history
  • Loading branch information
bartjkdp committed Oct 22, 2024
1 parent a33f904 commit a1286bd
Show file tree
Hide file tree
Showing 34 changed files with 840 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ docs/_build/
target/
Python.gitignore
venv/
.venv

# Notepad++ backups #
*.bak
Expand Down
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ RUN set -eux; \
rm -rf /var/lib/apt/lists/*

COPY app/requirements /app/requirements
COPY app/signals/apps/classification/requirements.txt /app/signals/apps/classification/requirements.txt

RUN set -eux; \
pip install --no-cache -r /app/requirements/requirements.txt; \
pip install --no-cache -r /app/signals/apps/classification/requirements.txt; \
pip install --no-cache tox; \
chgrp signals /app; \
chmod g+w /app; \
Expand Down
2 changes: 1 addition & 1 deletion app/requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ click-repl==0.3.0
# via celery
cron-descriptor==1.4.3
# via django-celery-beat
cryptography==43.0.1
cryptography==43.0.0
# via
# azure-storage-blob
# josepy
Expand Down
2 changes: 1 addition & 1 deletion app/requirements/requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ cron-descriptor==1.4.3
# via
# -r requirements_test.txt
# django-celery-beat
cryptography==43.0.1
cryptography==43.0.0
# via
# -r requirements_test.txt
# azure-storage-blob
Expand Down
2 changes: 1 addition & 1 deletion app/requirements/requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ cron-descriptor==1.4.3
# via
# -r requirements.txt
# django-celery-beat
cryptography==43.0.1
cryptography==43.0.0
# via
# -r requirements.txt
# azure-storage-blob
Expand Down
5 changes: 3 additions & 2 deletions app/signals/apps/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@
# Status message search
re_path(r'v1/private/status-messages/search/?$', StatusMessageSearchView.as_view(), name='status-message-search'),

# Legacy prediction proxy endpoint, still needed
path('category/prediction', LegacyMlPredictCategoryView.as_view(), name='ml-tool-predict-proxy'),
# # Legacy prediction proxy endpoint, still needed
# path('category/prediction', LegacyMlPredictCategoryView.as_view(), name='ml-tool-predict-proxy'),
path('', include('signals.apps.classification.urls')),

# The base routes of the API
path('v1/', include(base_router.urls)),
Expand Down
Empty file.
9 changes: 9 additions & 0 deletions app/signals/apps/classification/admin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django.contrib import admin

from signals.apps.classification.admin.admins import TrainingSetAdmin, ClassifierAdmin
from signals.apps.classification.models import TrainingSet
from signals.apps.classification.models.classifier import Classifier

# Expose the classification models in the Django admin, each with its
# dedicated ModelAdmin subclass.
admin.site.register(TrainingSet, admin_class=TrainingSetAdmin)
admin.site.register(Classifier, admin_class=ClassifierAdmin)

113 changes: 113 additions & 0 deletions app/signals/apps/classification/admin/admins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import this

from django.contrib import admin, messages

from signals.apps.classification.models import Classifier
from signals.apps.classification.tasks import train_classifier
import openpyxl


class TrainingSetAdmin(admin.ModelAdmin):
    """
    Admin for uploaded training sets.

    Besides the regular list view this admin offers an action that validates
    the selected training set workbooks and queues a training job for each.
    """
    list_display = ('name', 'file', )
    actions = ["run_training_with_training_set"]

    # Header names that must be present in the first row of the first sheet.
    _REQUIRED_COLUMNS = ("Main", "Sub", "Text")

    @admin.action(description="Train model met geselecteerde dataset")
    def run_training_with_training_set(self, request, queryset):
        """
        Validate the selected training sets and queue a training job for each.

        Every selected training set is validated before anything is queued:
        the workbook must contain the required columns (Main, Sub and Text)
        and at least one non-empty data row below the header. When a training
        set fails validation an error message is shown and no job is queued
        at all, so a partially valid selection never results in a partially
        started batch.
        """
        training_sets = list(queryset)

        # Validate the complete selection up front; queueing only starts once
        # every training set has passed.
        for training_set in training_sets:
            error = self._validation_error(training_set)
            if error is not None:
                self.message_user(request, error, messages.ERROR)
                return

        for training_set in training_sets:
            train_classifier.delay(training_set.id)

        self.message_user(
            request,
            "Training of the model has been initiated.",
            messages.SUCCESS,
        )

    def _validation_error(self, training_set):
        """Return an error message when the workbook is unusable, else None."""
        wb = openpyxl.load_workbook(training_set.file)
        first_sheet = wb.active

        headers = [cell.value for cell in first_sheet[1]]
        missing_columns = [col for col in self._REQUIRED_COLUMNS if col not in headers]
        if missing_columns:
            return (
                f"Training set {training_set.name} is missing required columns: "
                f"{', '.join(missing_columns)}"
            )

        # iter_rows yields one tuple per row; a tuple of only None values is
        # still truthy, so inspect the cell values instead of the tuples.
        has_data = any(
            any(value is not None for value in row)
            for row in first_sheet.iter_rows(min_row=2, values_only=True)
        )
        if not has_data:
            return f"The training set {training_set.name} does not contain any data rows."

        return None


class ClassifierAdmin(admin.ModelAdmin):
    """
    Admin for trained classifiers.

    Creating or changing classifiers by hand in the admin interface is
    disabled; a successful training job should create its own classifier
    object. Classifiers can be listed, activated through the admin action,
    and deleted.
    """
    list_display = ('name', 'precision', 'recall', 'accuracy', 'is_active', )
    actions = ["activate_classifier"]
    # NOTE: get_readonly_fields() below does not consult this attribute, so it
    # has no effect on which fields are rendered read-only.
    readonly_fields = ('training_status', 'training_error', )

    @admin.action(description="Maak deze classifier actief")
    def activate_classifier(self, request, queryset):
        """
        Make the chosen classifier active and deactivate all other classifiers.

        Exactly one classifier must be selected. Both updates run inside a
        single transaction, so a failure cannot leave the system without an
        active classifier.
        """
        # Imported locally to keep this block self-contained.
        from django.db import DatabaseError, transaction

        # Fetch at most two rows: enough to tell "exactly one" from "several"
        # without loading a large selection.
        selected = list(queryset[:2])
        if len(selected) != 1:
            self.message_user(
                request,
                "You can only make one classifier active.",
                messages.ERROR
            )
            return

        classifier = selected[0]

        try:
            with transaction.atomic():
                Classifier.objects.exclude(pk=classifier.pk).update(is_active=False)
                Classifier.objects.filter(pk=classifier.pk).update(is_active=True)
        except DatabaseError:
            self.message_user(
                request,
                f"Classifier {classifier.name} has not been activated.",
                messages.ERROR
            )
        else:
            self.message_user(
                request,
                f"Classifier {classifier.name} has been activated.",
                messages.SUCCESS
            )

    def get_readonly_fields(self, request, obj=None):
        """Return every model field as read-only for an existing object."""
        if obj:
            return [f.name for f in self.model._meta.fields]
        return []

    def has_add_permission(self, request):
        """Disallow adding classifiers through the admin."""
        return False

    def has_change_permission(self, request, obj=None):
        """Disallow changing classifiers through the admin."""
        return False

    def has_delete_permission(self, request, obj=None):
        """Allow deleting classifiers through the admin."""
        # NOTE(review): this returns True without checking the user's delete
        # permission -- confirm that is intended.
        return True
6 changes: 6 additions & 0 deletions app/signals/apps/classification/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class ClassificationConfig(AppConfig):
    """Django application configuration for the classification app."""

    verbose_name = 'Classificatie management'
    name = 'signals.apps.classification'
Empty file.
Empty file.
19 changes: 19 additions & 0 deletions app/signals/apps/classification/management/commands/train-ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from django.core.management.base import BaseCommand, CommandError

from signals.apps.classification.models import TrainingSet
from signals.apps.classification.tasks import train_classifier

class Command(BaseCommand):
    """Management command that trains a classifier from one training set."""

    help = "Train specific model"

    def add_arguments(self, parser):
        """Register the positional integer argument ``training_set_id``."""
        parser.add_argument("training_set_id", type=int)

    def handle(self, *args, **options):
        """
        Look up the requested training set and run the training for it.

        Raises CommandError when no training set with the given id exists.
        The training function is called directly, not through ``.delay``.
        """
        requested_id = options["training_set_id"]

        try:
            training_set = TrainingSet.objects.get(pk=requested_id)
        except TrainingSet.DoesNotExist:
            raise CommandError('Training Set "%s" does not exist' % requested_id)

        train_classifier(training_set.id)

42 changes: 42 additions & 0 deletions app/signals/apps/classification/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Generated by Django 4.2.11 on 2024-09-17 09:53

from django.db import migrations, models

import signals.apps.services.domain.checker_factories
import signals.apps.services.domain.mimetypes
import signals.apps.services.validator.file
import signals.apps.signals.models.utils


class Migration(migrations.Migration):
    """Initial migration of the classification app: creates the TrainingSet model."""

    initial = True

    dependencies = [
    ]

    operations = [
        migrations.CreateModel(
            name='TrainingSet',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('created_at', models.DateTimeField(auto_now_add=True)),
                ('name', models.CharField(max_length=255)),
                # Uploaded workbook; the validators restrict it to the .xlsx
                # mimetype and apply mimetype, content and size checks.
                ('file', models.FileField(max_length=255, upload_to='training_sets/%Y/%m/%d/', validators=[
                    signals.apps.services.validator.file.MimeTypeAllowedValidator(
                        signals.apps.services.domain.mimetypes.MimeTypeFromContentResolverFactory(),
                        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
                    ),
                    signals.apps.services.validator.file.MimeTypeIntegrityValidator(
                        signals.apps.services.domain.mimetypes.MimeTypeFromContentResolverFactory(),
                        signals.apps.services.domain.mimetypes.MimeTypeFromFilenameResolverFactory()
                    ),
                    signals.apps.services.validator.file.ContentIntegrityValidator(
                        signals.apps.services.domain.mimetypes.MimeTypeFromContentResolverFactory(),
                        signals.apps.services.domain.checker_factories.ContentCheckerFactory()
                    ),
                    # 20971520 == 20 * 1024 * 1024; presumably a size limit in
                    # bytes (20 MiB) -- confirm against FileSizeValidator.
                    signals.apps.services.validator.file.FileSizeValidator(20971520)
                ])),
            ],
        ),
    ]
23 changes: 23 additions & 0 deletions app/signals/apps/classification/migrations/0002_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 4.2.11 on 2024-09-17 11:09

from django.db import migrations, models


class Migration(migrations.Migration):
    """Creates the Classifier model with its two model file fields."""

    dependencies = [
        ('classification', '0001_initial'),
    ]

    operations = [
        migrations.CreateModel(
            name='Classifier',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('created_at', models.DateTimeField(auto_now_add=True)),
                ('name', models.CharField(max_length=255)),
                # Both file fields are required at this point in the history.
                ('middle_model', models.FileField(max_length=255, upload_to='classification_models/middle/%Y/%m/%d/')),
                ('middle_sub_model', models.FileField(max_length=255, upload_to='classification_models/middle_sub/%Y/%m/%d/')),
            ],
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Generated by Django 4.2.11 on 2024-09-17 11:58

from django.db import migrations, models


class Migration(migrations.Migration):
    """Adds the accuracy, precision and recall float fields to Classifier."""

    dependencies = [
        ('classification', '0002_classifier'),
    ]

    # Each field gets default=0, which is also the value used for rows that
    # already exist when the column is added.
    operations = [
        migrations.AddField(
            model_name='classifier',
            name='accuracy',
            field=models.FloatField(default=0),
        ),
        migrations.AddField(
            model_name='classifier',
            name='precision',
            field=models.FloatField(default=0),
        ),
        migrations.AddField(
            model_name='classifier',
            name='recall',
            field=models.FloatField(default=0),
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Generated by Django 4.2.11 on 2024-09-19 08:07

from django.db import migrations, models


class Migration(migrations.Migration):
    """
    Replaces the middle_model/middle_sub_model fields with main_model/sub_model
    and makes the metric fields on Classifier nullable.
    """

    dependencies = [
        ('classification', '0003_classifier_accuracy_classifier_precision_and_more'),
    ]

    operations = [
        # The old fields are removed and new ones added (not renamed), so
        # values stored in the old columns are not carried over.
        migrations.RemoveField(
            model_name='classifier',
            name='middle_model',
        ),
        migrations.RemoveField(
            model_name='classifier',
            name='middle_sub_model',
        ),
        # The new file fields are optional (blank/null); their upload paths
        # still use the "middle" directory names here.
        migrations.AddField(
            model_name='classifier',
            name='main_model',
            field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/middle/%Y/%m/%d/'),
        ),
        migrations.AddField(
            model_name='classifier',
            name='sub_model',
            field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/middle_sub/%Y/%m/%d/'),
        ),
        # The metric fields become optional while keeping default=0.
        migrations.AlterField(
            model_name='classifier',
            name='accuracy',
            field=models.FloatField(blank=True, default=0, null=True),
        ),
        migrations.AlterField(
            model_name='classifier',
            name='precision',
            field=models.FloatField(blank=True, default=0, null=True),
        ),
        migrations.AlterField(
            model_name='classifier',
            name='recall',
            field=models.FloatField(blank=True, default=0, null=True),
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Generated by Django 4.2.15 on 2024-09-27 10:25

from django.db import migrations, models


class Migration(migrations.Migration):
    """Adds Classifier.is_active and changes the upload paths of the model files."""

    dependencies = [
        ('classification', '0004_remove_classifier_middle_model_and_more'),
    ]

    operations = [
        # New classifiers, and rows that already exist, start out inactive.
        migrations.AddField(
            model_name='classifier',
            name='is_active',
            field=models.BooleanField(default=False),
        ),
        # Only upload_to changes: the "middle"/"middle_sub" directories become
        # "main"/"main_sub".
        migrations.AlterField(
            model_name='classifier',
            name='main_model',
            field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/main/%Y/%m/%d/'),
        ),
        migrations.AlterField(
            model_name='classifier',
            name='sub_model',
            field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/main_sub/%Y/%m/%d/'),
        ),
    ]
Loading

0 comments on commit a1286bd

Please sign in to comment.