# Simple DistilBERT Model Training and Inferring Workflow for Paragraph Sorting with a Labeling Interface (GUI + CLI)

Ondrej Špánik © 2024-09
- GUI Workflow
- CLI Package Setup (Optional)
- Model Settings
- Stage 1: Training Dataset (Manual Labeling)
- Stage 2: Inferred Dataset (Automated Labeling)
- Stage 3: Validation (Optional)
- Useful Links
## GUI Workflow

You can launch the GUI without prior package setup by navigating to the repository folder and running:

    python3 gui.py

The only prerequisite is an installed Python 3 (ideally the latest version).
The GUI provides a user-friendly interface for managing the entire process of training, inferring, and validating the DistilBERT model for paragraph sorting.
Here's an overview of the main components:
- Environment Setup:
  - The GUI automatically creates a virtual environment and installs the required packages (a minimal sketch of this bootstrap follows the list).
  - It checks for an existing environment installation to avoid redundant setups.
- Training Section:
  - Allows users to choose or generate label files (`train_labels.json`).
  - Provides an interface to assign labels to paragraphs using a browser-based tool.
  - Enables selection of the training input file (`train_input.json`).
  - Includes a button to start the training process.
- Inferring Section:
  - Offers options to choose or create an input file for inference (`infer_input.md`).
  - Allows users to start the inferring process on the prepared input.
  - Provides a checkbox to output only the top label for each paragraph.
- Validation Section:
  - Includes a button to start the validation process.
  - Displays validation results in a text area.
  - Allows saving of the validation output to a file.
- File Management:
  - The GUI continuously checks for the existence of required files and updates button states accordingly.
  - Provides buttons to open and view the various input and output files.
- Help Menu:
  - Includes a README option for quick access to the help documentation.
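For the curious, the environment bootstrap can be done with just the standard library. This is an illustrative sketch only, not the actual `gui.py` code; the `venv` directory name and `requirements.txt` path mirror the CLI setup described below:

```python
# Illustrative sketch (not the actual gui.py code): create a virtual
# environment once, then reuse it on later runs.
import os
import subprocess
import venv

VENV_DIR = "venv"  # assumed name, matching the CLI setup below

def ensure_environment() -> str:
    """Create the venv and install requirements only if missing."""
    python_bin = os.path.join(VENV_DIR, "bin", "python")  # POSIX layout
    if not os.path.exists(python_bin):  # avoid redundant setups
        venv.EnvBuilder(with_pip=True).create(VENV_DIR)
        subprocess.check_call(
            [python_bin, "-m", "pip", "install", "-r", "requirements.txt"]
        )
    return python_bin
```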
## CLI Package Setup (Optional)

After cloning this repository and navigating into it, create and activate a virtual environment using `python3 -m venv venv` and `source venv/bin/activate` so that Python packages are installed locally within the repository's context, then install them using `pip3 install -r requirements.txt`.
## Model Settings

Currently, the DistilBERT training is set up as follows (these values proved most effective in practice):
- `save_total_limit`: 1 (keep only the last checkpoint)
- `learning_rate`: 2e-5
- `per_device_train_batch_size`: 8
- `num_train_epochs`: 16
- `weight_decay`: 0.01
If you wish to use different values, adjust `distilbert_train.py` manually (a sketch of how these settings map onto the Hugging Face API follows).
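For orientation, here is how those values would look with the Hugging Face `Trainer` API, which is the usual way to fine-tune DistilBERT. This is a sketch under that assumption, not an excerpt of `distilbert_train.py`; the `output_dir` name is a placeholder:

```python
# Sketch: the hyperparameters above expressed as Hugging Face
# TrainingArguments (assumes distilbert_train.py uses the Trainer API).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="model",          # placeholder for the model directory
    save_total_limit=1,          # keep only the last checkpoint
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=16,
    weight_decay=0.01,
)
```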
## Stage 1: Training Dataset (Manual Labeling)

In this stage, we feed a manually labeled dataset to the model for training, enabling it to infer labels for an automated dataset in the second stage.
Files: `train_labels.json`, `train_input.json`
- Create `train_labels.json` manually in the following format:

      {
        "0": {"label": "poetry", "color": "#00FF00", "increase_if": [], "decrease_if": [], "must_have": []},
        "1": {"label": "description", "color": "#FFA500", "increase_if": [], "decrease_if": [], "must_have": []},
        "2": {"label": "spiritual", "color": "#FFD700", "increase_if": [], "decrease_if": [], "must_have": ["god", "jesus", "religion"]},
        "3": {"label": "sadness", "color": "#4169E1", "increase_if": [], "decrease_if": [], "must_have": []},
        "4": {"label": "psychology", "color": "#800080", "increase_if": [], "decrease_if": [], "must_have": []}
      }

- Prepare text with paragraphs for manual sorting (training).
- Use `labeling.html` to load `train_labels.json` and sort the paragraphs into buckets. Launch command for Google Chrome: `chrome labeling.html`
- Export the buckets as `train_input.json`.
- Run `distilbert_train.py` to train the model on `train_input.json`. This will generate the model's directory. Note that training currently validates/tests on the same data (edit the script yourself for improvements; one possible approach is sketched below). Launch command: `python3 distilbert_train.py`
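As mentioned in the last step, the stock script evaluates on the same data it trains on. One possible improvement is to hold out a validation split before fine-tuning. The sketch below assumes `train_input.json` is a list of `{"text": ..., "label": ...}` records, which may not match the actual export format:

```python
# Sketch only: hold out a validation split instead of evaluating on
# the training data. The record layout of train_input.json is an
# assumption here, not the documented format.
import json
from sklearn.model_selection import train_test_split

with open("train_input.json", encoding="utf-8") as f:
    records = json.load(f)

train_set, val_set = train_test_split(
    records,
    test_size=0.2,    # hold out 20% for validation
    random_state=42,  # reproducible split
)
print(f"{len(train_set)} training / {len(val_set)} validation examples")
```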
## Stage 2: Inferred Dataset (Automated Labeling)

After training the model, we use it to infer labels for an unlabeled dataset, which can be the rest of your data that hasn't been manually labeled yet.
Files: `infer_input.md`
- Prepare text with paragraphs for automated (inferred) sorting.
- Put the text in a clear format into `infer_input.md`.
- Run `distilbert_infer.py` to execute the model. This will generate the output file `infer_output.json`. Launch command: `python3 distilbert_infer.py [options]` (a sketch of the underlying inference call follows the examples below)
Available options:

- `-c` or `--no-normalize`: do not normalize scores
- `-l` or `--top-label`: only output the top label name
- `-s` or `--skip-long`: automatically skip paragraphs exceeding the maximum length
Examples:

- Run with default settings: `python3 distilbert_infer.py`
- Run without score normalization: `python3 distilbert_infer.py -c`
- Output only the top label: `python3 distilbert_infer.py -l`
- Combine options: `python3 distilbert_infer.py -c -l`
- Skip long paragraphs: `python3 distilbert_infer.py -s`
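As a point of reference, the core of such an inference step with a fine-tuned DistilBERT classifier can look like the sketch below. It uses the standard `transformers` pipeline and a placeholder model path; `distilbert_infer.py` itself may work differently:

```python
# Sketch only: score one paragraph with a fine-tuned DistilBERT
# classifier. "./model" is a placeholder for the Stage 1 output directory.
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model="./model",  # placeholder path to the trained model
    top_k=None,       # return scores for every label
)

paragraph = "The moon hung low over the silent fields."
scores = classifier(paragraph, truncation=True)[0]
print(max(scores, key=lambda s: s["score"]))  # top label, as with -l
```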
## Stage 3: Validation (Optional)

In this stage, we validate the output of the inference process by comparing it to the original training data and analyzing the distribution of predicted labels.
Files: `train_input.json`, `infer_output.json`, `train_labels.json`
- Ensure you have the required files from the previous stages.
- Run `distilbert_validate.py` to perform the validation. This script will:
  - Compare the assigned labels in `infer_output.json` to the labels in `train_input.json`.
  - Print any mistakes found during the comparison.
  - Calculate and display metrics including precision, recall, and F1 score for the top label.
  - Calculate and display overall metrics for all labels.
  - Show the percentage of correct and incorrect predictions.
  - Calculate and display the percentage distribution of predicted labels.

  Launch command: `python3 distilbert_validate.py`
- Review the output, which will include:
  - A list of mistakes (if any) showing the text, predicted label, and correct label.
  - Top Label Metrics (precision, recall, F1 score; a sketch of how such metrics are computed follows this list).
  - Overall Metrics (precision, recall, F1 score).
  - Prediction Accuracy (percentage of correct and incorrect predictions).
  - The total number of mistakes and the total number of texts in the training dataset.
  - Label Percentages showing the distribution of predicted labels.
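The metrics above are standard classification metrics. For reference, here is a minimal sketch of how precision, recall, F1, and the label distribution can be computed with scikit-learn; `y_true` and `y_pred` are hypothetical stand-ins for the labels read from `train_input.json` and `infer_output.json`:

```python
# Sketch only: precision/recall/F1 and the label distribution via
# scikit-learn. y_true and y_pred are hypothetical example lists.
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

y_true = ["poetry", "sadness", "spiritual", "poetry"]
y_pred = ["poetry", "sadness", "psychology", "poetry"]

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted", zero_division=0
)
print(f"Precision: {precision:.2f}  Recall: {recall:.2f}  F1: {f1:.2f}")

# Percentage distribution of predicted labels
for label, n in Counter(y_pred).items():
    print(f"{label}: {100 * n / len(y_pred):.1f}%")
```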
This validation stage helps assess the model's performance, identify areas for improvement in the training process, and understand the distribution of labels in the inferred dataset.