forked from Amsterdam/signals
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add ml training within Django admin
- Loading branch information
Showing
34 changed files
with
840 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,7 @@ docs/_build/ | |
target/ | ||
Python.gitignore | ||
venv/ | ||
.venv | ||
|
||
# Notepad++ backups # | ||
*.bak | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from django.contrib import admin | ||
|
||
from signals.apps.classification.admin.admins import TrainingSetAdmin, ClassifierAdmin | ||
from signals.apps.classification.models import TrainingSet | ||
from signals.apps.classification.models.classifier import Classifier | ||
|
||
admin.site.register(TrainingSet, TrainingSetAdmin) | ||
admin.site.register(Classifier, ClassifierAdmin) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import this | ||
|
||
from django.contrib import admin, messages | ||
|
||
from signals.apps.classification.models import Classifier | ||
from signals.apps.classification.tasks import train_classifier | ||
import openpyxl | ||
|
||
|
||
class TrainingSetAdmin(admin.ModelAdmin): | ||
list_display = ('name', 'file', ) | ||
actions = ["run_training_with_training_set"] | ||
|
||
@admin.action(description="Train model met geselecteerde dataset") | ||
def run_training_with_training_set(self, request, queryset): | ||
""" | ||
Run validation, if validation fails show an error message. | ||
First we validate if there are no missing columns (Main, Sub and Text column are required), after this we check if there is atleast one row of data (next | ||
to the headers) | ||
""" | ||
for training_set in queryset: | ||
file = training_set.file | ||
|
||
wb = openpyxl.load_workbook(file) | ||
first_sheet = wb.active | ||
|
||
headers = [cell.value for cell in first_sheet[1]] | ||
required_columns = ["Main", "Sub", "Text"] | ||
missing_columns = [col for col in required_columns if col not in headers] | ||
|
||
if missing_columns: | ||
self.message_user( | ||
request, | ||
f"Training set { training_set.name } is missing required columns: {', '.join(missing_columns)}", | ||
messages.ERROR, | ||
) | ||
|
||
return | ||
|
||
data_rows = list(first_sheet.iter_rows(min_row=2, values_only=True)) | ||
if not any(data_rows): | ||
self.message_user( | ||
request, | ||
f"The training set { training_set.name } does not contain any data rows.", | ||
messages.ERROR | ||
) | ||
return | ||
|
||
train_classifier.delay(training_set.id) | ||
|
||
self.message_user( | ||
request, | ||
"Training of the model has been initiated.", | ||
messages.SUCCESS, | ||
) | ||
|
||
|
||
class ClassifierAdmin(admin.ModelAdmin): | ||
""" | ||
Creating or disabling classifiers by hand in the Admin interface is disabled, | ||
a successful training job should create his own classifier object. | ||
""" | ||
list_display = ('name', 'precision', 'recall', 'accuracy', 'is_active', ) | ||
actions = ["activate_classifier"] | ||
readonly_fields = ('training_status', 'training_error', ) | ||
|
||
@admin.action(description="Maak deze classifier actief") | ||
def activate_classifier(self, request, queryset): | ||
""" | ||
Make chosen classifier active, disable other classifiers | ||
""" | ||
|
||
if queryset.count() > 1: | ||
self.message_user( | ||
request, | ||
"You can only make one classifier active.", | ||
messages.ERROR | ||
) | ||
return | ||
|
||
try: | ||
Classifier.objects.update(is_active=False) | ||
Classifier.objects.filter(id=queryset.first().id).update(is_active=True) | ||
|
||
self.message_user( | ||
request, | ||
f"Classifier { queryset.first().name } has been activated.", | ||
messages.SUCCESS | ||
) | ||
except Exception: | ||
self.message_user( | ||
request, | ||
f"Classifier { queryset.first().name } has not been activated.", | ||
messages.ERROR | ||
) | ||
|
||
|
||
|
||
def get_readonly_fields(self, request, obj=None): | ||
if obj: | ||
return [f.name for f in self.model._meta.fields] | ||
return [] | ||
|
||
def has_add_permission(self, request): | ||
return False | ||
|
||
def has_change_permission(self, request, obj=None): | ||
return False | ||
|
||
def has_delete_permission(self, request, obj=None): | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class ClassificationConfig(AppConfig): | ||
name = 'signals.apps.classification' | ||
verbose_name = 'Classificatie management' |
Empty file.
Empty file.
19 changes: 19 additions & 0 deletions
19
app/signals/apps/classification/management/commands/train-ml.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from django.core.management.base import BaseCommand, CommandError | ||
|
||
from signals.apps.classification.models import TrainingSet | ||
from signals.apps.classification.tasks import train_classifier | ||
|
||
class Command(BaseCommand): | ||
help = "Train specific model" | ||
|
||
def add_arguments(self, parser): | ||
parser.add_argument("training_set_id", type=int) | ||
|
||
def handle(self, *args, **options): | ||
try: | ||
training_set = TrainingSet.objects.get(pk=options["training_set_id"]) | ||
except TrainingSet.DoesNotExist: | ||
raise CommandError('Training Set "%s" does not exist' % options["training_set_id"]) | ||
|
||
train_classifier(training_set.id) | ||
|
42 changes: 42 additions & 0 deletions
42
app/signals/apps/classification/migrations/0001_initial.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Generated by Django 4.2.11 on 2024-09-17 09:53 | ||
|
||
from django.db import migrations, models | ||
|
||
import signals.apps.services.domain.checker_factories | ||
import signals.apps.services.domain.mimetypes | ||
import signals.apps.services.validator.file | ||
import signals.apps.signals.models.utils | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
initial = True | ||
|
||
dependencies = [ | ||
] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name='TrainingSet', | ||
fields=[ | ||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
('created_at', models.DateTimeField(auto_now_add=True)), | ||
('name', models.CharField(max_length=255)), | ||
('file', models.FileField(max_length=255, upload_to='training_sets/%Y/%m/%d/', validators=[ | ||
signals.apps.services.validator.file.MimeTypeAllowedValidator( | ||
signals.apps.services.domain.mimetypes.MimeTypeFromContentResolverFactory(), | ||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' | ||
), | ||
signals.apps.services.validator.file.MimeTypeIntegrityValidator( | ||
signals.apps.services.domain.mimetypes.MimeTypeFromContentResolverFactory(), | ||
signals.apps.services.domain.mimetypes.MimeTypeFromFilenameResolverFactory() | ||
), | ||
signals.apps.services.validator.file.ContentIntegrityValidator( | ||
signals.apps.services.domain.mimetypes.MimeTypeFromContentResolverFactory(), | ||
signals.apps.services.domain.checker_factories.ContentCheckerFactory() | ||
), | ||
signals.apps.services.validator.file.FileSizeValidator(20971520) | ||
])), | ||
], | ||
), | ||
] |
23 changes: 23 additions & 0 deletions
23
app/signals/apps/classification/migrations/0002_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Generated by Django 4.2.11 on 2024-09-17 11:09 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('classification', '0001_initial'), | ||
] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name='Classifier', | ||
fields=[ | ||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
('created_at', models.DateTimeField(auto_now_add=True)), | ||
('name', models.CharField(max_length=255)), | ||
('middle_model', models.FileField(max_length=255, upload_to='classification_models/middle/%Y/%m/%d/')), | ||
('middle_sub_model', models.FileField(max_length=255, upload_to='classification_models/middle_sub/%Y/%m/%d/')), | ||
], | ||
), | ||
] |
28 changes: 28 additions & 0 deletions
28
.../apps/classification/migrations/0003_classifier_accuracy_classifier_precision_and_more.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Generated by Django 4.2.11 on 2024-09-17 11:58 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('classification', '0002_classifier'), | ||
] | ||
|
||
operations = [ | ||
migrations.AddField( | ||
model_name='classifier', | ||
name='accuracy', | ||
field=models.FloatField(default=0), | ||
), | ||
migrations.AddField( | ||
model_name='classifier', | ||
name='precision', | ||
field=models.FloatField(default=0), | ||
), | ||
migrations.AddField( | ||
model_name='classifier', | ||
name='recall', | ||
field=models.FloatField(default=0), | ||
), | ||
] |
46 changes: 46 additions & 0 deletions
46
app/signals/apps/classification/migrations/0004_remove_classifier_middle_model_and_more.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Generated by Django 4.2.11 on 2024-09-19 08:07 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('classification', '0003_classifier_accuracy_classifier_precision_and_more'), | ||
] | ||
|
||
operations = [ | ||
migrations.RemoveField( | ||
model_name='classifier', | ||
name='middle_model', | ||
), | ||
migrations.RemoveField( | ||
model_name='classifier', | ||
name='middle_sub_model', | ||
), | ||
migrations.AddField( | ||
model_name='classifier', | ||
name='main_model', | ||
field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/middle/%Y/%m/%d/'), | ||
), | ||
migrations.AddField( | ||
model_name='classifier', | ||
name='sub_model', | ||
field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/middle_sub/%Y/%m/%d/'), | ||
), | ||
migrations.AlterField( | ||
model_name='classifier', | ||
name='accuracy', | ||
field=models.FloatField(blank=True, default=0, null=True), | ||
), | ||
migrations.AlterField( | ||
model_name='classifier', | ||
name='precision', | ||
field=models.FloatField(blank=True, default=0, null=True), | ||
), | ||
migrations.AlterField( | ||
model_name='classifier', | ||
name='recall', | ||
field=models.FloatField(blank=True, default=0, null=True), | ||
), | ||
] |
28 changes: 28 additions & 0 deletions
28
...assification/migrations/0005_classifier_is_active_alter_classifier_main_model_and_more.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Generated by Django 4.2.15 on 2024-09-27 10:25 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('classification', '0004_remove_classifier_middle_model_and_more'), | ||
] | ||
|
||
operations = [ | ||
migrations.AddField( | ||
model_name='classifier', | ||
name='is_active', | ||
field=models.BooleanField(default=False), | ||
), | ||
migrations.AlterField( | ||
model_name='classifier', | ||
name='main_model', | ||
field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/main/%Y/%m/%d/'), | ||
), | ||
migrations.AlterField( | ||
model_name='classifier', | ||
name='sub_model', | ||
field=models.FileField(blank=True, max_length=255, null=True, upload_to='classification_models/main_sub/%Y/%m/%d/'), | ||
), | ||
] |
Oops, something went wrong.