Extensive Multilabel Classification of Brain MRI Scans for Infarcts Using the Swin UNETR Architecture in Deep Learning Applications
Abstract
Objective
To distinguish infarct location and type with the utmost precision using the advantages of the Swin UNEt TRansformers (Swin UNETR) architecture.
Methods
The research employed a two-phase training approach. In the first phase, the Swin UNETR model was trained using the Ischemic Stroke Lesion Segmentation Challenge (ISLES) 2022 dataset, which included cases of acute and subacute infarcts. The second phase involved training with data from 309 patients. The 110 categories result from classifying infarcts based on 22 specific brain regions; each region is divided into right and left sides, and each side includes four types of infarcts (acute, acute lacunar, subacute, subacute lacunar). The unique architecture of Swin UNETR, integrating elements of both the transformer and U-Net designs with a hierarchical transformer computed with shifted windows, played a crucial role in the study.
Results
During Swin UNETR training with the ISLES 2022 dataset, batch loss decreased to 0.8885±0.1897, with training and validation dice scores reaching 0.4224±0.0710 and 0.4827±0.0607, respectively. The optimal model weight had a validation dice score of 0.5747. In the patient data model, batch loss decreased to 0.0565±0.0427, with final training and validation accuracies of 0.9842±0.0005 and 0.9837±0.0010.
Conclusion
The results of this study surpass the accuracy of similar studies, but overfitting remains an issue, highlighting the need for future efforts to improve generalizability. Such detailed classifications could significantly aid physicians in diagnosing infarcts in clinical settings.
INTRODUCTION
Evolution of deep learning architectures: from perceptrons to Swin UNEt TRansformers
Artificial neurons are designed to mimic the way biological neurons fire signals when they receive a sufficient number of inputs from other neurons [1]. In 1943, McCulloch and Pitts [1] introduced an early artificial neuron model, known as the threshold logic unit (TLU). This model compares a weighted sum of input signals to a threshold to determine the neuron’s output, marking the beginning of artificial neural networks [1]. In 1957, Rosenblatt [2] invented the perceptron using a modified TLU that applies a step function to the weighted sum of the inputs. To address the limitations of perceptrons highlighted by Minsky and Papert [3], particularly their inability to solve exclusive-OR classification problems, the multilayer perceptron (MLP) was developed by stacking multiple perceptrons [4]. Rumelhart et al. [4] developed the groundbreaking backpropagation algorithm, and convolutional neural networks (CNNs) were later developed, inspired by the consecutive layered structure of neurons and the local receptive fields of the visual cortex [5]. For years, deep CNNs have led various visual recognition tasks [6]. However, they have limitations in learning the long-range dependencies that are crucial for segmenting lesions of different shapes and sizes [7], and excellent results require supervised training of a large network with millions of parameters, whereas in the medical field it is difficult to obtain even thousands of training images [8]. The U-Net, a fully convolutional network, was proposed to compensate for these difficulties. It features a contracting path for context capture and a symmetric expanding path for precise localization, maximizing the use of annotated samples through data augmentation [8].
In 2017, Google researchers introduced the transformer architecture, revolutionizing neural machine translation by using attention mechanisms instead of convolutional layers [9]. Building on this, Hatamizadeh et al. [5] reformulated volumetric medical image segmentation as a sequence-to-sequence prediction problem, introducing UNEt TRansformers (UNETR), which combines a transformer encoder with a U-shaped network. This structure effectively captures multi-scale patterns and delivers precise semantic segmentation. To overcome the challenges arising from differences between language and vision, such as large variations in the scale of visual entities and the high resolution of pixels in images, Liu et al. [10] proposed a hierarchical transformer using shifted windows. Inspired by the success of vision transformers, Hatamizadeh et al. [7] proposed a novel segmentation model termed Swin UNEt TRansformers (Swin UNETR). Swin UNETR’s hierarchical transformer efficiently processes high-resolution images, capturing the fine details necessary for accurate lesion classification in brain magnetic resonance imaging (MRI). Its U-shaped design with skip connections enables exact lesion localization, while the integration of attention mechanisms significantly enhances feature focus, which is crucial for identifying different lesion types.
The necessity of this study
In the acute phase of stroke care, determining the location of the infarct is essential for clinical decisions, such as patient triage, stroke mechanism investigation, and additional therapies [11]. While lesion–symptom mapping studies have advanced our understanding of brain-behavior relationships, modern imaging techniques have now surpassed the capabilities of traditional lesion analysis [12]. Voxel-based lesion-symptom mapping (VLSM) techniques involve comparing neurobehavioral scores across patients by analyzing the presence or absence of lesions on a voxel-by-voxel basis [12]. VLSM methods have proven useful in elucidating the impact of infarct location on motor [13], language [14], and cognitive recovery [15]. For example, the location of the infarct was the most significant factor influencing positive cognitive outcomes [15]. Given the significance of infarct location, this study aimed to distinguish infarct location and type with the utmost precision using the advantages of the Swin UNETR architecture. Accordingly, based on radiological interpretations, the datasets were divided into 110 categories, and multilabel annotation was used so that multiple categories could apply to a single case. Because of the limited resources available to an individual researcher, this study was planned within the scope feasible through Google Colab [16].
METHODS
This study was conducted in two phases using the Ischemic Stroke Lesion Segmentation Challenge (ISLES) 2022 open dataset and brain MRI data from 309 patients. First, we trained the Swin UNETR model on the ISLES 2022 dataset to create segmentation masks, focusing on differentiating infarctions from normal parenchyma. Second, the model was trained using data from actual patients with official radiological interpretations to learn to classify the regions already identified as infarcts.
Training Swin UNETR with the ISLES 2022 dataset
ISLES 2022 dataset
ISLES is a specialized competition to promote and advance methods for the automated segmentation of ischemic stroke lesions, held with a slightly different focus each year since 2015. The 2022 challenge concentrated on the segmentation of infarcts in multimodal MRI scans, encompassing both acute strokes (0 to 7 days post-onset) and subacute strokes (1 to 3 weeks post-onset) [17]. It aimed to identify not only large infarct lesions but also multiple embolic infarcts across 400 cases [17]. Included participants were 18 years or older, had undergone brain MR imaging for diagnosed or suspected stroke, and had imaging that included at least fluid-attenuated inversion recovery (FLAIR) and diffusion-weighted imaging (DWI) sequences [17]. Image acquisition was performed on one of the following devices: Philips Achieva 3T MRI scanner (Philips Healthcare), Philips Ingenia 3T MRI scanner (Philips Healthcare), Siemens Verio 3T MRI scanner (Siemens Healthineers), Siemens MAGNETOM Avanto 1.5T MRI scanner (Siemens Healthineers), or Siemens MAGNETOM Aera 1.5T MRI scanner (Siemens Healthineers) [17]. In this study, the ISLES 2022 dataset was divided into training, validation, and inference datasets in an 8:1:1 ratio. Data augmentation techniques included random 90-degree image rotation, random flipping across the axial, sagittal, and coronal planes, and random adjustments to scale and intensity, as sketched below. Permission to use these data was granted by the ISLES 2022 organizer, Ezequiel de la Rosa.
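For illustration, an augmentation pipeline of this kind could be assembled with MONAI's dictionary transforms as follows; this is a minimal sketch, and the dictionary keys and probabilities are assumptions rather than the exact pipeline used in the study.

```python
# A minimal sketch of the described augmentations using MONAI's
# dictionary transforms; keys and probabilities are illustrative.
from monai.transforms import (
    Compose,
    RandRotate90d,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
)

train_transforms = Compose([
    # Random 90-degree rotation
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    # Random flips across the axial, sagittal, and coronal planes
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # Random scale and intensity adjustments (applied to the image only)
    RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
])
```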
Swin UNETR implementation
Swin UNETR was implemented using PyTorch-Ignite [18] and MONAI [19] and trained on Google Colab [16] with A100 graphics processing units. The model was created with a feature size of 48, which is compatible with the self-supervised pre-trained weights [20], and the Swin UNETR encoder was initialized from pre-trained weights [21]. Training used the adaptive moment estimation with decoupled weight decay (AdamW) optimizer [22] with an initial learning rate of 1e-5 and a weight decay of 1e-1, applying dice focal loss for about 24,000 iterations. We set the maximum number of epochs to 400 and configured the patience, the number of events to wait without improvement before stopping training, to 40 [23]. This approach halts training when further improvements are unlikely, helping to prevent overfitting and saving computational resources. A detailed description of the Swin UNETR architecture is provided in Supplementary Table S1, and a minimal sketch of this setup is shown below.
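The following sketch assembles this setup with MONAI's SwinUNETR, DiceFocalLoss, and PyTorch's AdamW, matching the stated hyperparameters (feature size 48, learning rate 1e-5, weight decay 1e-1); the ROI size, channel counts, loss arguments, and weight-loading call are illustrative assumptions, not the authors' exact training script.

```python
# A minimal sketch of the training setup described above, assuming
# MONAI's SwinUNETR and DiceFocalLoss; arguments beyond those stated
# in the text are illustrative.
import torch
from monai.networks.nets import SwinUNETR
from monai.losses import DiceFocalLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SwinUNETR(
    img_size=(96, 96, 96),   # assumed ROI size
    in_channels=2,           # e.g., DWI + FLAIR (assumption)
    out_channels=2,          # infarct vs. normal parenchyma
    feature_size=48,         # matches the self-supervised pre-trained weights
    use_checkpoint=True,
).to(device)

# Encoder initialization from pre-trained weights (path is a placeholder)
# weights = torch.load("model_swinvit.pt")
# model.load_from(weights)

loss_fn = DiceFocalLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-1)
```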
Loss function
The training loss function was dice focal loss, which computes both dice loss (DL) and focal loss and returns the weighted sum of these two losses, calculated voxel-wise [24]. Focal loss is an extension of binary cross entropy (BCE) loss between the target and the input probabilities [24]. The dice coefficient (DICE) for binary classification can be written as:

$$\mathrm{DICE}_t = \frac{2\sum_{i=1}^{N} p_{it}\, g_{it}}{\sum_{i=1}^{N} p_{it} + \sum_{i=1}^{N} g_{it} + \epsilon}$$
The variable $g_{it}$, which can be either 0 or 1, denotes the ground truth label of class $t$ for pixel $i$, with $N$ representing the total number of pixels in the image [25]. The variable $p_{it}$ represents the output probability, which falls within the range of 0 to 1, and the term $\epsilon$ is employed to prevent the numerical issue of division by zero [25]. The DL formula is as follows, where $\omega_t$ represents the weight corresponding to each class $t$ [25]:

$$\mathrm{DL} = \sum_{t} \omega_t \left(1 - \mathrm{DICE}_t\right)$$
The focal dice loss (FDL) is defined by applying a factor of $1/\beta$ as the exponent to $\mathrm{DICE}_t$ for each class, where the exponent parameter $\beta$ is greater than or equal to 1 [25]:

$$\mathrm{FDL} = \sum_{t} \left(1 - \mathrm{DICE}_t^{1/\beta}\right)$$
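To make the formulas concrete, the following is a small PyTorch sketch of the per-class DICE, DL, and FDL as reconstructed above; the tensor shapes and class weights are illustrative assumptions.

```python
# A worked sketch of the loss equations above; p and g are assumed to be
# flattened per-class probability and ground-truth tensors of shape
# (num_classes, N), and `weights` holds the per-class weights w_t.
import torch

def dice_per_class(p, g, eps=1e-5):
    # DICE_t = 2 * sum_i(p_it * g_it) / (sum_i(p_it) + sum_i(g_it) + eps)
    intersection = (p * g).sum(dim=1)
    return 2.0 * intersection / (p.sum(dim=1) + g.sum(dim=1) + eps)

def dice_loss(p, g, weights):
    # DL = sum_t w_t * (1 - DICE_t)
    return (weights * (1.0 - dice_per_class(p, g))).sum()

def focal_dice_loss(p, g, beta=2.0):
    # FDL = sum_t (1 - DICE_t ** (1 / beta)), with beta >= 1
    return (1.0 - dice_per_class(p, g) ** (1.0 / beta)).sum()
```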
Training loss was recorded every 100 iterations, and at the end of every epoch.
Evaluation metrics
Mean dice was used as the evaluation metric, calculating the dice score from full-size tensors and averaging over batches, class-channels, and iterations [26]. Train mean dice score and validation mean dice score were calculated using the validation and inference datasets.
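For illustration, an evaluation loop of this kind could use MONAI's DiceMetric as sketched below; `model`, `device`, and `val_loader` are assumed from the setup above, and the post-processing transforms are illustrative assumptions.

```python
# A minimal sketch of mean dice evaluation with MONAI's DiceMetric.
import torch
from monai.data import decollate_batch
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete

dice_metric = DiceMetric(include_background=True, reduction="mean")
post_pred = AsDiscrete(argmax=True, to_onehot=2)   # binarize predictions
post_label = AsDiscrete(to_onehot=2)               # one-hot ground truth

model.eval()
with torch.no_grad():
    for batch in val_loader:
        outputs = model(batch["image"].to(device))
        preds = [post_pred(o) for o in decollate_batch(outputs)]
        labels = [post_label(l) for l in decollate_batch(batch["label"].to(device))]
        dice_metric(y_pred=preds, y=labels)

mean_dice = dice_metric.aggregate().item()  # averaged over batches and classes
dice_metric.reset()
```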
Training a classification model with patients’ brain MRI data
Patient data
We included subjects aged 18 years or older who underwent MR imaging for diagnosed stroke at Pohang SeMyeong Christianity Hospital. Images were acquired using a Siemens MAGNETOM Vida 3T MRI scanner (Siemens Healthineers) and included DWI, apparent diffusion coefficient, and FLAIR modalities in all cases. Because the ISLES 2022 dataset focuses exclusively on segmentation of acute and subacute infarcts, cases with only chronic or old infarcts were excluded. Cases presenting only with intracranial hemorrhage, subarachnoid hemorrhage, intraventricular hemorrhage, or other lesions, without infarct lesions, were also omitted. This study aimed to classify as meticulously as possible based on the interpretations of the Department of Radiology. For instance, although the regions indicated by the basal ganglia, the caudate, and the lenticulostriate artery territory overlap, they were distinguished in radiological interpretations and established as independent categories among the 110 in total (Table 1).
Among the categories, borderzone infarcts and lacunar infarcts followed the definitions set by the Department of Radiology. Borderzone infarcts are characterized as ischemic lesions typically found at the intersection of two major arterial territories [27], and lacunar infarcts are defined as small brain lesions (0.2 to 15 mm³) [28].
The datasets from 309 patients were randomly divided into training, validation, and inference datasets in an 8:1:1 ratio, with data augmentation techniques similar to those used for the ISLES 2022 dataset. As the objective of this study is multilabel classification, all categories relevant to a single case were assigned and subsequently transformed into one-hot encoded vectors for further processing, as sketched below.
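A minimal sketch of this multilabel target encoding follows: every category that applies to a case is set to 1 in a 110-dimensional multi-hot vector. The category indices used in the example are hypothetical.

```python
# A minimal sketch of multilabel target encoding over the 110 categories.
import torch

NUM_CLASSES = 110

def encode_labels(category_indices):
    """category_indices: list of class indices assigned to one case."""
    target = torch.zeros(NUM_CLASSES, dtype=torch.float32)
    target[category_indices] = 1.0  # all applicable categories set to 1
    return target

# e.g., a case carrying two labels at once (indices are hypothetical)
y = encode_labels([7, 52])
```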
In this study, we confirm that the use of this clinical data was in full compliance with the ethical standards of the Pohang SeMyeong Christianity Hospital’s Institutional Review Board (IRB no. PSMCHIRB-2024-16-1), ensuring adherence to necessary protocols for patient privacy and data security.
Classification model implementation
The classification model employs a Swin UNETR backbone; the weights with the best checkpoint mean dice score of 0.5747 were loaded and used. The model inputs are tailored to image regions of interest, and the model is designed to classify images into a predefined number of classes. Post-transformer processing includes an adaptive average pooling and layer normalization sequence, ensuring a consistent and stabilized feature representation before classification. The final classification is achieved through a fully connected linear layer that maps the extracted and normalized features to the respective class probabilities. The maximum number of epochs was set to 400, and the patience was configured to 50. A sketch of this architecture is shown below.
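The following is a minimal sketch of this classification architecture under stated assumptions: the encoder attribute (`swinViT`) and the 768-channel deepest feature map (16 × the feature size of 48) follow MONAI's SwinUNETR implementation, but the exact wiring and checkpoint path are illustrative, not the authors' published code.

```python
# A minimal sketch of the described classifier: the Swin UNETR transformer
# encoder followed by adaptive average pooling, layer normalization, and a
# fully connected layer mapping features to 110 class scores.
import torch
import torch.nn as nn
from monai.networks.nets import SwinUNETR

class SwinClassifier(nn.Module):
    def __init__(self, num_classes=110, feature_size=48):
        super().__init__()
        backbone = SwinUNETR(
            img_size=(96, 96, 96),
            in_channels=2,
            out_channels=2,
            feature_size=feature_size,
        )
        # backbone.load_state_dict(torch.load("best_checkpoint.pt"))  # placeholder path
        self.encoder = backbone.swinViT          # transformer encoder only
        feat_dim = 16 * feature_size             # deepest hidden-state channels
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.norm = nn.LayerNorm(feat_dim)
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        hidden_states = self.encoder(x)          # list of multi-scale features
        feats = self.pool(hidden_states[-1])     # (B, feat_dim, 1, 1, 1)
        feats = self.norm(feats.flatten(1))      # (B, feat_dim)
        return self.head(feats)                  # per-class logits
```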
Loss function
The chosen loss function was BCE loss between the target and the input probabilities [24]:

$$\mathrm{BCE} = -\frac{1}{M}\sum_{m=1}^{M}\left[y_m \log h_\theta(x_m) + \left(1 - y_m\right)\log\left(1 - h_\theta(x_m)\right)\right]$$

where $M$ is the number of training samples, $y_m$ is the ground truth label for training sample $m$, $x_m$ is the input for training sample $m$, and $h_\theta$ is the hypothesis (model) with weights $\theta$ [29]. Training loss was recorded every 5 iterations, and training and validation accuracies were calculated on the respective datasets at the end of each epoch.
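As an illustration, a multilabel criterion of this form can be expressed with PyTorch's BCEWithLogitsLoss, which fuses the sigmoid and the BCE above for numerical stability; `model`, `images`, and `targets` are assumed from the classification setup, and the choice of the logits variant over nn.BCELoss is our assumption.

```python
# A minimal sketch of the multilabel criterion, assuming raw logits from
# the classifier; BCEWithLogitsLoss applies a sigmoid internally and
# computes the BCE above element-wise over the 110 labels.
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

logits = model(images)             # (B, 110) raw class scores (assumed)
loss = criterion(logits, targets)  # targets: (B, 110) multi-hot vectors
loss.backward()
```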
Evaluation metrics
We calculated the true positive (TP), true negative (TN), false positive (FP), and false negative (FN) values by comparing the predictions of the classification model with the ground truth labels. Accuracy for our multilabel data was derived as [30]:

$$\mathrm{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
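A sketch of this computation over multi-hot predictions follows; the 0.5 decision threshold is an assumption, as the paper does not state one.

```python
# A minimal sketch of the accuracy computation: element-wise TP/TN/FP/FN
# over all labels, with predictions thresholded at 0.5 (assumption).
import torch

def multilabel_accuracy(logits, targets, threshold=0.5):
    preds = (torch.sigmoid(logits) >= threshold).float()
    tp = ((preds == 1) & (targets == 1)).sum()
    tn = ((preds == 0) & (targets == 0)).sum()
    fp = ((preds == 1) & (targets == 0)).sum()
    fn = ((preds == 0) & (targets == 1)).sum()
    return (tp + tn).float() / (tp + tn + fp + fn).float()
```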
RESULTS
Outcomes of training Swin UNETR using the ISLES 2022 dataset
During training of Swin UNETR with the ISLES 2022 dataset, the batch loss showed considerable fluctuation over 23,200 iterations but exhibited a decreasing trend as epochs progressed, reaching a value of 0.8885±0.1897 (Table 2, Fig. 1A). The training mean dice score was 0.4224±0.0710, indicating a general upward trend (Table 2, Fig. 1B). The validation mean dice score reached 0.4827±0.0607 (Table 2, Fig. 1C), increasing with the progression of epochs but then declining after 100 epochs. Based on these results, we selected the model weights with a validation mean dice score of 0.5747 as the best for further training and validation on actual patient brain MRI images. Training and validation were halted by an early stopping function set with a patience of 50 epochs.
Outcomes of training a classification model using the patients’ dataset
For the classification model trained with patient data, batch loss significantly decreased from the start and then plateaued after 1,000 iterations, reaching 0.0565±0.0427 (Table 3, Fig. 2A). Training accuracy remained high, close to the maximum value for up to 25 epochs, then slightly decreased, showing a final value of 0.9842±0.0005 (Table 3, Fig. 2B). Validation accuracy similarly stayed high up to 21 epochs before decreasing, with a final value of 0.9837±0.0010 (Table 3, Fig. 2C). The training and validation phases were also stopped using an early stopping function with a patience of 50 epochs.
DISCUSSION
In summary, the Swin UNETR-based classification model, which was trained on brain MRI scans for infarct segmentation and categorized into 110 distinct classes, achieved an accuracy of 0.9837±0.0010. This significantly surpasses the accuracies reported in similar studies. For instance, Subudhi et al. [31] achieved an accuracy of 90.23% in classifying three types of stroke according to the Oxfordshire Community Stroke Project classification. Additionally, Cetinoglu et al. [32] reported a 93% accuracy in classifying three vascular territories (anterior cerebral artery, middle cerebral artery, watershed).
The superior performance of Swin UNETR over the MobileNetV2 and EfficientNet-B0 CNN models used in the two studies above can be attributed to two main factors. Firstly, architectural differences are critical. Cetinoglu et al. [32] used modified versions of MobileNetV2 and EfficientNet-B0; these convolutional neural networks have limitations in modeling long-range information due to the limited kernel size of their convolution layers. In contrast, Swin UNETR outperformed previous winning methodologies such as a 3D segmentation network with residual connections [33], no-new-Net [34], and Multimodal Brain Tumor Segmentation Using Transformer [35] in the research by Hatamizadeh et al. [7]. Secondly, Swin UNETR’s pretraining on a specialized dataset of computed tomography (CT) scans likely contributed to its enhanced accuracy. The pretrained Swin UNETR encoder used weights from a cohort of 5,050 CT scans [20], whereas MobileNetV2 and EfficientNet-B0 were pretrained on the ImageNet dataset [36], a general visual database not specific to medical images. This difference in pretraining context and specificity is believed to have played a role in the improved accuracy.
However, our study has limitations. The decrease in validation accuracy after 20 epochs, as observed in Fig. 2, suggests overfitting. This issue is somewhat inevitable given the large number of classification categories (110) relative to the small number of patient MRI datasets (N=309). Although the categories were delineated based on radiological interpretations, merging classes that share overlapping areas could be attempted to reduce the total number of classes and thereby mitigate overfitting. However, as stated in our research objectives, as physiatrists dealing with brain imaging and patient symptoms, our primary goal was to distinguish infarct location and type with the utmost precision. Due to time and resource constraints, we were unable to perform hyperparameter tuning or to engage in the repetitive task of adjusting and experimenting with the number of categories. Future work, including acquiring more MRI data and employing regularization, dropout, and other techniques, is expected to improve generalizability.
Additionally, training exclusively on the ISLES 2022 dataset, which focuses on acute and subacute infarcts, did not include chronic or old infarcts. With the acquisition of more patient data in the future, it would be possible to conduct studies on chronic and old lesions.
Currently, several commercially available artificial intelligence (AI)-based software solutions for stroke demonstrate diagnostic accuracy for brain infarcts with MRI and other imaging techniques, with sensitivities and specificities of 44% to 83% and 57% to 93% for tissue ischemia, 80% to 96% and 90% to 98% for large vessel occlusion, and 96% and 95% for hemorrhage detection [37]. For instance, in detecting and quantifying the ischemic core and penumbra, the e-ASPECTS software can outperform non-stroke experts and is at least as effective as stroke experts in applying the Alberta Stroke Programme Early CT Score (ASPECTS) for patients with acute ischemic stroke [38]. This paper is based on open datasets and brain MRI scans from 309 actual patients who visited the emergency room at Pohang SeMyeong Christianity Hospital. Although this is a relatively small dataset for deep learning applications, we achieved excellent results through augmentation techniques. Rehabilitation medicine, which prioritizes both brain MRI images and patients’ symptoms, is ideally suited for such research, but clinical physicians’ participation in such studies is low. By presenting its data collection methods and sample sizes, this paper may spark future interest and research among clinicians who are unfamiliar with deep learning and uncertain about how to begin research in this area. Embracing deep learning in rehabilitation medicine can significantly transform patient care: it enhances diagnostic accuracy, personalizes treatment plans, improves efficiency, and ultimately leads to better patient outcomes.
Conclusion
The classification model based on Swin UNETR architecture achieved high accuracy for 110 classes of acute and subacute brain infarcts through a two-stage learning process. This study highlights the potential for achieving detailed classification using open datasets as a supplement to limited patient data. This method is especially useful in scenarios where segmentation is constrained by resources. Such detailed classifications could significantly aid physicians in diagnosing infarcts in clinical settings.
Notes
CONFLICTS OF INTEREST
No potential conflict of interest relevant to this article was reported.
FUNDING INFORMATION
None.
AUTHOR CONTRIBUTION
Conceptualization: Oh J. Methodology: Oh J, An H. Formal analysis: Oh J. Project administration: Oh J, An H. Visualization: Oh J. Writing – original draft: Oh J. Writing – review and editing: Oh J, An H. Approval of final manuscript: all authors.
SUPPLEMENTARY MATERIALS
Supplementary materials can be found via https://doi.org/10.5535/arm.230029.