I’m training a DETR object detection model on a custom dataset, following this Roboflow article:
How to Train RT-DETR on a Custom Dataset with Transformers
When I run the Trainer API with a custom compute_metrics function, I get a “CUDA out of memory” error. I have reduced the batch size from 16 all the way down to 1, but I still get the same error. Here is how I create the compute_metrics function:
from dataclasses import dataclass

import numpy as np
import supervision as sv
import torch
from torchmetrics.detection import MeanAveragePrecision

# Map between class indices and class names from the training dataset
id2label = {id: label for id, label in enumerate(train_ds.classes)}
label2id = {label: id for id, label in enumerate(train_ds.classes)}


@dataclass
class ModelOutput:
    logits: torch.Tensor
    pred_boxes: torch.Tensor


class MAPEvaluator:

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label

    def collect_image_sizes(self, targets):
        """Collect image sizes across the dataset as list of tensors with shape [batch_size, 2]."""
        image_sizes = []
        for batch in targets:
            batch_image_sizes = torch.tensor(np.array([x["size"] for x in batch]))
            image_sizes.append(batch_image_sizes)
        return image_sizes

    def collect_targets(self, targets, image_sizes):
        post_processed_targets = []
        for target_batch, image_size_batch in zip(targets, image_sizes):
            # Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2)
            for target, (height, width) in zip(target_batch, image_size_batch):
                boxes = target["boxes"]
                boxes = sv.xcycwh_to_xyxy(boxes)
                boxes = boxes * np.array([width, height, width, height])
                boxes = torch.tensor(boxes)
                labels = torch.tensor(target["class_labels"])
                post_processed_targets.append({"boxes": boxes, "labels": labels})
        return post_processed_targets

    def collect_predictions(self, predictions, image_sizes):
        post_processed_predictions = []
        for batch, target_sizes in zip(predictions, image_sizes):
            batch_logits, batch_boxes = batch[1], batch[2]
            output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
            post_processed_output = self.image_processor.post_process_object_detection(
                output, threshold=self.threshold, target_sizes=target_sizes
            )
            post_processed_predictions.extend(post_processed_output)
        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results):
        predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

        image_sizes = self.collect_image_sizes(targets)
        post_processed_targets = self.collect_targets(targets, image_sizes)
        post_processed_predictions = self.collect_predictions(predictions, image_sizes)

        evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        evaluator.warn_on_many_detections = False
        evaluator.update(post_processed_predictions, post_processed_targets)
        metrics = evaluator.compute()

        # Replace the list of per-class metrics with a separate metric for each class
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}
        return metrics


eval_compute_metrics_fn = MAPEvaluator(image_processor=processor, threshold=0.01, id2label=id2label)
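For reference, my understanding is that with eval_do_concat_batches=False (set in the training arguments below), the Trainer hands compute_metrics the predictions and targets as nested per-batch lists. Here is a minimal CPU-only sketch of that structure, with made-up shapes and class names, and a SimpleNamespace standing in for the Trainer’s EvalPrediction object (processor is the image processor loaded earlier in the notebook):

from types import SimpleNamespace

num_queries, num_classes = 300, 2
fake_logits = np.random.randn(2, num_queries, num_classes).astype(np.float32)  # [batch, queries, classes]
fake_boxes = np.random.rand(2, num_queries, 4).astype(np.float32)              # normalized (cx, cy, w, h)
fake_predictions = [(None, fake_logits, fake_boxes)]                           # a single eval batch

fake_targets = [[  # the matching eval batch with two images
    {"size": np.array([480, 640]),               # (height, width)
     "boxes": np.array([[0.5, 0.5, 0.2, 0.3]]),  # normalized (cx, cy, w, h)
     "class_labels": np.array([0])},
    {"size": np.array([480, 640]),
     "boxes": np.array([[0.25, 0.25, 0.1, 0.1]]),
     "class_labels": np.array([1])},
]]

dummy_evaluator = MAPEvaluator(image_processor=processor, threshold=0.01, id2label={0: "class_a", 1: "class_b"})
dummy_results = SimpleNamespace(predictions=fake_predictions, label_ids=fake_targets)
print(dummy_evaluator(dummy_results))  # exercises only the metric code, entirely on the CPU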
Below is the training setup for the custom dataset:
training_args = TrainingArguments(
    output_dir="Malaria-finetune",
    report_to="none",
    num_train_epochs=10,
    max_grad_norm=0.1,
    learning_rate=5e-5,
    warmup_steps=300,
    per_device_train_batch_size=1,
    dataloader_num_workers=2,
    metric_for_best_model="eval_map",
    greater_is_better=True,
    load_best_model_at_end=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    remove_unused_columns=False,
    eval_do_concat_batches=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=pytorch_dataset_train,
    eval_dataset=pytorch_dataset_valid,
    processing_class=processor,
    data_collator=collate_fn,
    compute_metrics=eval_compute_metrics_fn,
)

trainer.train()
When I run the trainer, I get the error “OutOfMemoryError: CUDA out of memory.” From my research, it’s most likely the compute_metrics function that’s causing the error. Please help.
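For reference, two evaluation-related memory settings in TrainingArguments that I have not set are per_device_eval_batch_size (it defaults to 8 and is independent of per_device_train_batch_size) and eval_accumulation_steps (it moves the accumulated predictions to the CPU every N evaluation steps instead of keeping them all on the GPU until compute_metrics runs). A sketch of how they would be added to the setup above, in case they are relevant:

training_args = TrainingArguments(
    output_dir="Malaria-finetune",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,   # defaults to 8 when unset; independent of the train batch size
    eval_accumulation_steps=1,      # offload accumulated eval predictions to the CPU every step
    eval_strategy="epoch",
    remove_unused_columns=False,
    eval_do_concat_batches=False,
    # ... remaining arguments as in the setup above ...
)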