
Trains a convolutional neural network (CNN) with the Keras framework (via the keras3 package) to classify landscapes by their spatial patterns, using the raw pixel data as input. The function uses a multiscale CNN architecture designed to distinguish among landscape pattern types.

Usage

train_nn_pixels(
  landscapes,
  cv_method = "k-fold",
  cv_folds = 5,
  epochs = 50,
  batch_size = 16,
  learning_rate = 0.001,
  architecture = "multiscale",
  dropout_rate = 0.3,
  dense_units = 128,
  model_path = NULL,
  loss = "categorical_crossentropy",
  optimizer = "adam",
  metrics = c("accuracy"),
  validation_split = 0,
  callbacks = NULL,
  patience = 15,
  verbose = TRUE
)

Arguments

landscapes

List. Landscape objects created by create_landscape or create_landscapes. **Note**: Input landscapes must contain categorical (discrete) habitat data, e.g., 0/1 for two habitat types or 0/1/2 for three. Continuous data (e.g., elevation or gradients) are not supported. All landscapes used for training must have the same spatial extent.
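Because only categorical inputs of identical size are accepted, it can help to check the pixel grids before training. The sketch below assumes each landscape's grid is available as a plain numeric matrix (how to extract it from a landscape object is not documented here); `check_training_rasters()` and its `max_classes` threshold are hypothetical, not part of the package.

```r
# Hypothetical pre-flight check (not part of the package): assumes each
# landscape's pixel grid is available as a plain numeric matrix.
check_training_rasters <- function(grids, max_classes = 10) {
  # All grids must share the same dimensions (rows x columns)
  dims <- vapply(grids, function(g) paste(dim(g), collapse = "x"), character(1))
  if (length(unique(dims)) != 1) {
    stop("Landscapes differ in size: ", paste(unique(dims), collapse = ", "))
  }
  for (i in seq_along(grids)) {
    v <- grids[[i]][!is.na(grids[[i]])]
    # Categorical habitat codes must be whole numbers (0, 1, 2, ...)
    if (any(v != round(v))) {
      stop("Landscape ", i, " contains non-integer (continuous) values")
    }
    # Heuristic: many distinct codes suggest continuous data
    if (length(unique(v)) > max_classes) {
      stop("Landscape ", i, " has too many distinct values to be categorical")
    }
  }
  invisible(TRUE)
}

binary <- matrix(sample(0:1, 100, replace = TRUE), 10, 10)
three  <- matrix(sample(0:2, 100, replace = TRUE), 10, 10)
check_training_rasters(list(binary, three))           # passes silently
elevation <- matrix(runif(100), 10, 10)
try(check_training_rasters(list(binary, elevation)))  # error: continuous values
```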

cv_method

Character. Cross-validation method: "none", "k-fold", or "loo" for leave-one-out (default: "k-fold").

  • "k-fold" or "loo": Runs cross-validation, returns the resulting performance metrics, and then trains a final model on all data.

  • "none": Trains on all provided data without cross-validation. To evaluate performance, use apply_nn_pixels on a separate test set.

cv_folds

Integer. Number of cross-validation folds when cv_method = "k-fold" (default: 5). Note: the number of folds may be reduced automatically to ensure that each fold contains enough samples.
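A minimal base-R sketch of how stratified k-fold assignment can work. The function `assign_folds()` and its cap on `k` (no more folds than members of the smallest class) are illustrative assumptions; the exact rule `train_nn_pixels` uses to reduce `cv_folds` is not documented here. Leave-one-out ("loo") is the limiting case with a single held-out landscape per fold.

```r
# Sketch: stratified k-fold assignment (illustrative, not the package's code).
assign_folds <- function(labels, k = 5, seed = 1) {
  set.seed(seed)
  # Assumed rule: no more folds than members of the smallest class,
  # so every class can appear in every fold.
  k <- min(k, min(table(labels)))
  folds <- integer(length(labels))
  for (cls in unique(labels)) {
    idx <- which(labels == cls)
    # Shuffle within the class, then deal indices round-robin into folds
    shuffled <- idx[sample.int(length(idx))]
    folds[shuffled] <- rep_len(seq_len(k), length(idx))
  }
  folds
}

labels <- rep(c("bands", "clustered", "diffuse"), times = c(33, 33, 34))
folds <- assign_folds(labels, k = 5)
table(folds, labels)  # each class is spread evenly over the 5 folds
```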

epochs

Integer. Number of training epochs (default: 50).

batch_size

Integer. Batch size for training (default: 16).

learning_rate

Numeric. Learning rate for Adam optimizer (default: 0.001).

architecture

Character. CNN architecture (default: "multiscale"). Currently only "multiscale" is supported, which uses multiple kernel sizes (3x3 and 5x5) to capture patterns at different spatial scales.

dropout_rate

Numeric. Dropout rate for regularization (0-1, default: 0.3). Higher values reduce overfitting but may decrease model capacity. Applied between convolutional and dense layers.
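Keras implements dropout as "inverted" dropout: during training each activation is set to zero with probability `rate` and the surviving activations are scaled by 1 / (1 - rate), so the expected activation is unchanged; at prediction time the layer does nothing. A base-R sketch of that rule (`apply_dropout()` is an illustrative name, not a package function):

```r
# Minimal sketch of inverted dropout as applied during training.
apply_dropout <- function(x, rate) {
  # Keep each unit with probability 1 - rate
  keep <- runif(length(x)) >= rate
  # Zero dropped units; rescale the survivors
  ifelse(keep, x / (1 - rate), 0)
}

set.seed(42)
activations <- rep(1, 10000)
out <- apply_dropout(activations, rate = 0.3)
mean(out == 0)  # close to 0.3
mean(out)       # close to 1
```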

dense_units

Integer. Number of units in the final dense layer before output (default: 128). Controls model capacity for learning complex pattern combinations.

model_path

Character. File path for saving the trained model; models are saved in the `.keras` format (default: NULL, in which case the model is not saved).

loss

Character. Loss function for training (default: "categorical_crossentropy"). Use "sparse_categorical_crossentropy" if labels are integers rather than one-hot encoded. See loss_categorical_crossentropy for details.
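The two losses compute the same quantity and differ only in how the true class is supplied. A base-R sketch with made-up probabilities; note that R indexes classes from 1, whereas Keras expects 0-based integers for sparse labels.

```r
# Two predicted probability rows over three classes (made-up values)
probs <- matrix(c(0.7, 0.2, 0.1,
                  0.1, 0.3, 0.6), nrow = 2, byrow = TRUE)
int_labels <- c(1L, 3L)           # integer labels (1-based in R)
one_hot <- diag(3)[int_labels, ]  # the same labels, one-hot encoded

# categorical_crossentropy: weights log-probabilities by the one-hot matrix
cce <- -rowSums(one_hot * log(probs))
# sparse_categorical_crossentropy: picks the true class's probability directly
scce <- -log(probs[cbind(seq_along(int_labels), int_labels)])

all.equal(cce, scce)  # TRUE: same loss, different label format
mean(cce)             # batch loss, about 0.434
```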

optimizer

Character. Optimizer algorithm: "adam" (default), "sgd", "rmsprop". Adam is recommended for most cases. See optimizer_adam. Note: Advanced optimizer parameters (e.g., momentum, beta values) are not currently exposed.

metrics

Character vector. Metrics to track during training (default: c("accuracy")). Other options include "categorical_accuracy" and "top_k_categorical_accuracy". Metrics are reported for monitoring only and do not affect training. See compile.
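To illustrate what "top_k_categorical_accuracy" measures, here is a base-R sketch with made-up predictions (`top_k_accuracy()` is an illustrative helper, not a Keras function). In Keras the default is k = 5, which is a lenient criterion when there are only six landscape types, as in the Examples section.

```r
# A prediction counts as correct when the true class is among the
# k highest-probability classes.
top_k_accuracy <- function(probs, true_class, k = 5) {
  hits <- vapply(seq_len(nrow(probs)), function(i) {
    true_class[i] %in% order(probs[i, ], decreasing = TRUE)[seq_len(k)]
  }, logical(1))
  mean(hits)
}

probs <- matrix(c(0.6, 0.3, 0.1,
                  0.2, 0.5, 0.3,
                  0.1, 0.3, 0.6), nrow = 3, byrow = TRUE)
true_class <- c(1, 3, 2)  # column index of the correct class in each row

top_k_accuracy(probs, true_class, k = 1)  # 1/3: only row 1 is correct
top_k_accuracy(probs, true_class, k = 2)  # 1: every true class is in the top 2
```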

validation_split

Numeric. Fraction of the training data (0-1, default: 0) to hold out as a validation set during final model training; passed to fit(). When > 0, enables monitoring of, and early stopping on, validation loss. Particularly useful when cv_method = "none", to guard against overfitting. Ignored during CV fold training, which uses its own validation splits.
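According to the Keras documentation for `fit()`, the `validation_split` fraction is taken from the last samples of the data, before shuffling. If landscapes reached `fit()` ordered by class, the validation set could therefore contain only one or two classes. Whether `train_nn_pixels` shuffles the data beforehand is not stated here; the sketch below (with the hypothetical helper `split_indices()`) only shows which rows Keras would hold out for class-sorted data with the class counts from the Examples section.

```r
# Rows held out by validation_split, assuming the documented fit() behaviour:
# the last fraction of rows, selected before shuffling.
split_indices <- function(n, validation_split) {
  n_train <- floor(n * (1 - validation_split))
  list(train = seq_len(n_train),
       validation = if (n_train < n) seq.int(n_train + 1, n) else integer(0))
}

# Class-sorted labels with the counts shown in the Examples section
labels <- rep(c("bands", "clustered", "diffuse", "fingers", "random", "sharp"),
              times = c(33, 33, 34, 33, 33, 34))
s <- split_indices(length(labels), 0.2)
table(labels[s$validation])  # only "random" (6) and "sharp" (34)
```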

callbacks

List. Optional keras callbacks for advanced training control (default: NULL), e.g., early stopping, learning rate scheduling, or model checkpointing. Note: callbacks apply only to final model training; CV folds always use patience-based early stopping when patience is specified. For an overview of available callbacks, see callback_early_stopping (the callback used by default) and the related `callback_` functions.

patience

Integer. Number of epochs with no improvement after which training stops early (default: 15). Applied to both CV fold training (monitoring validation loss) and final model training (monitoring validation loss if `validation_split` > 0). Only used when callbacks = NULL. Set to NULL to train for the full number of epochs without early stopping. The value is passed to callback_early_stopping.
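The patience rule can be sketched in a few lines of base R: training stops once the monitored loss has failed to improve for `patience` consecutive epochs. `stop_epoch()` is an illustrative helper and the loss values are synthetic; `callback_early_stopping()` itself also supports options such as `min_delta` and `restore_best_weights` that this sketch ignores.

```r
# Epoch at which patience-based early stopping would halt training.
stop_epoch <- function(losses, patience) {
  if (is.null(patience)) return(length(losses))  # no early stopping
  best <- Inf
  wait <- 0
  for (epoch in seq_along(losses)) {
    if (losses[epoch] < best) {
      # Improvement: remember it and reset the counter
      best <- losses[epoch]
      wait <- 0
    } else {
      wait <- wait + 1
      if (wait >= patience) return(epoch)
    }
  }
  length(losses)
}

# Synthetic validation losses: best value at epoch 4, then no improvement
val_loss <- c(1.2, 0.9, 0.7, 0.65, 0.66, 0.70, 0.68, 0.71, 0.75, 0.80)
stop_epoch(val_loss, patience = 3)     # 7
stop_epoch(val_loss, patience = 15)    # 10: patience never exhausted
stop_epoch(val_loss, patience = NULL)  # 10
```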

verbose

Logical. Show training progress and performance summaries (default: TRUE). When TRUE, displays epoch-by-epoch training/validation metrics during final model training, plus CV fold accuracies and final performance summaries. CV fold epoch details are not shown. When FALSE, runs silently.

Value

List containing:

model

Trained keras model object

history

Training history object from keras3::fit()

classes

Character vector of class names used during training

input_shape

Integer vector of input dimensions (height, width, channels)

architecture

Character, architecture type used ("multiscale")

performance

Performance metrics. When cv_method != "none", contains results from evaluate_cv_performance(), including the confusion matrix, per-class metrics, and overall accuracy. When cv_method = "none", contains training metadata only (see its note field for evaluation instructions).
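The per-class metrics reported for cross-validation can be reproduced from the confusion matrix alone. Using the matrix printed in the k-fold example (rows are predictions, columns are actual classes), recall divides each diagonal entry by its column total, precision divides it by its row total, and F1 is their harmonic mean. This base-R check reproduces the printed values:

```r
# Confusion matrix from the k-fold example (rows = predicted, columns = actual)
classes <- c("bands", "clustered", "diffuse", "fingers", "random", "sharp")
cm <- matrix(c(26,  2,  1,  2,  0,  0,
                1, 19,  0,  7,  0,  0,
                4,  0, 32,  0,  0,  0,
                2,  9,  0, 20,  0,  4,
                0,  0,  1,  0, 33,  0,
                0,  3,  0,  4,  0, 30),
             nrow = 6, byrow = TRUE,
             dimnames = list(Predicted = classes, Actual = classes))

tp        <- diag(cm)
recall    <- tp / colSums(cm)  # share of each actual class that was found
precision <- tp / rowSums(cm)  # share of each predicted class that is correct
f1        <- 2 * precision * recall / (precision + recall)

round(data.frame(recall, precision, f1_score = f1), 2)
sum(tp) / sum(cm)              # overall accuracy: 0.8
```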

See also

apply_nn_pixels

Other neural network training: set_random_seed(), train_nn_metrics()

Examples

# \donttest{
# Create training data
training_landscapes <- create_landscapes(
  n = 200,
  patterns = c("sharp", "diffuse", "clustered", "fingers", "bands", "random")
)
#>  Successfully generated all 200 training landscapes

# Train with cross-validation
model <- train_nn_pixels(
  landscapes = training_landscapes,
  cv_method = "k-fold",
  cv_folds = 5
)
#> 
#> ── Landscape type distribution: ──
#> 
#> training_labels
#>     bands clustered   diffuse   fingers    random     sharp 
#>        33        33        34        33        33        34 
#> ── Cross-validation (k-fold, 5 folds) ──
#> 
#>  Fold 1/5 accuracy: 0.8333
#>  Fold 2/5 accuracy: 0.8333
#>  Fold 3/5 accuracy: 0.8095
#>  Fold 4/5 accuracy: 0.7895
#>  Fold 5/5 accuracy: 0.7222
#> 
#> ── Cross-validation results ──
#> 
#>  Method: 5-fold cross-validation
#>  Overall accuracy: 80%
#> 
#> ── Confusion matrix 
#>            Actual
#> Predicted   bands clustered diffuse fingers random sharp
#>   bands        26         2       1       2      0     0
#>   clustered     1        19       0       7      0     0
#>   diffuse       4         0      32       0      0     0
#>   fingers       2         9       0      20      0     4
#>   random        0         0       1       0     33     0
#>   sharp         0         3       0       4      0    30
#> 
#> ── Per-class performance 
#> # A tibble: 6 × 5
#>   class     count recall precision f1_score
#>   <chr>     <int>  <dbl>     <dbl>    <dbl>
#> 1 bands        33   0.79      0.84     0.81
#> 2 clustered    33   0.58      0.7      0.63
#> 3 diffuse      34   0.94      0.89     0.91
#> 4 fingers      33   0.61      0.57     0.59
#> 5 random       33   1         0.97     0.99
#> 6 sharp        34   0.88      0.81     0.85
#> 
#> ── Accuracy and loss across folds ──
#> 
#> Mean accuracy: 0.7976 +- 0.0459
#> Mean loss: 0.5956 +- 0.1579
#> 
#> 
#> ── Training final model on all data (validation split: 0) ──
#> 
#> Epoch 1 - loss: 1.9607 - accuracy: 0.1950
#> Epoch 2 - loss: 1.3769 - accuracy: 0.4950
#> Epoch 3 - loss: 0.5326 - accuracy: 0.8250
#> Epoch 4 - loss: 0.2971 - accuracy: 0.8950
#> Epoch 5 - loss: 0.1829 - accuracy: 0.9450
#> Epoch 6 - loss: 0.1632 - accuracy: 0.9650
#> Epoch 7 - loss: 0.1692 - accuracy: 0.9550
#> Epoch 8 - loss: 0.0680 - accuracy: 0.9800
#> Epoch 9 - loss: 0.0101 - accuracy: 1.0000
#> Epoch 10 - loss: 0.0019 - accuracy: 1.0000
#> Epoch 11 - loss: 0.0010 - accuracy: 1.0000
#> Epoch 12 - loss: 0.0010 - accuracy: 1.0000
#> Epoch 13 - loss: 0.0004 - accuracy: 1.0000
#> Epoch 14 - loss: 0.0004 - accuracy: 1.0000
#> Epoch 15 - loss: 0.0003 - accuracy: 1.0000
#> Epoch 16 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 17 - loss: 0.0003 - accuracy: 1.0000
#> Epoch 18 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 19 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 20 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 21 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 22 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 23 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 24 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 25 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 26 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 27 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 28 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 29 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 30 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 31 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 32 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 33 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 34 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 35 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 36 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 37 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 38 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 39 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 40 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 41 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 42 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 43 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 44 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 45 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 46 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 47 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 48 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 49 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 50 - loss: 0.0000 - accuracy: 1.0000

# Train without cross-validation on all data
final_model <- train_nn_pixels(
  landscapes = training_landscapes,
  cv_method = "none",
  epochs = 100
)
#> ── Landscape type distribution: ──
#> 
#> training_labels
#>     bands clustered   diffuse   fingers    random     sharp 
#>        33        33        34        33        33        34 
#> ── Training final model on all data ──
#> 
#>  Training on all data (validation split is 0)...
#> Epoch 1 - loss: 1.5645 - accuracy: 0.3450
#> Epoch 2 - loss: 1.0592 - accuracy: 0.5200
#> Epoch 3 - loss: 0.6512 - accuracy: 0.7200
#> Epoch 4 - loss: 0.4895 - accuracy: 0.7850
#> Epoch 5 - loss: 0.3321 - accuracy: 0.8700
#> Epoch 6 - loss: 0.1418 - accuracy: 0.9650
#> Epoch 7 - loss: 0.1040 - accuracy: 0.9600
#> Epoch 8 - loss: 0.0261 - accuracy: 0.9850
#> Epoch 9 - loss: 0.0285 - accuracy: 0.9950
#> Epoch 10 - loss: 0.3093 - accuracy: 0.9500
#> Epoch 11 - loss: 0.1795 - accuracy: 0.9350
#> Epoch 12 - loss: 0.0480 - accuracy: 0.9900
#> Epoch 13 - loss: 0.0136 - accuracy: 0.9950
#> Epoch 14 - loss: 0.0083 - accuracy: 1.0000
#> Epoch 15 - loss: 0.0014 - accuracy: 1.0000
#> Epoch 16 - loss: 0.0004 - accuracy: 1.0000
#> Epoch 17 - loss: 0.0004 - accuracy: 1.0000
#> Epoch 18 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 19 - loss: 0.0003 - accuracy: 1.0000
#> Epoch 20 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 21 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 22 - loss: 0.0002 - accuracy: 1.0000
#> Epoch 23 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 24 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 25 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 26 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 27 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 28 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 29 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 30 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 31 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 32 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 33 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 34 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 35 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 36 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 37 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 38 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 39 - loss: 0.0001 - accuracy: 1.0000
#> Epoch 40 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 41 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 42 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 43 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 44 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 45 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 46 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 47 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 48 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 49 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 50 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 51 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 52 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 53 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 54 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 55 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 56 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 57 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 58 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 59 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 60 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 61 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 62 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 63 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 64 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 65 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 66 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 67 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 68 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 69 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 70 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 71 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 72 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 73 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 74 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 75 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 76 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 77 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 78 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 79 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 80 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 81 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 82 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 83 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 84 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 85 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 86 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 87 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 88 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 89 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 90 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 91 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 92 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 93 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 94 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 95 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 96 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 97 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 98 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 99 - loss: 0.0000 - accuracy: 1.0000
#> Epoch 100 - loss: 0.0000 - accuracy: 1.0000
# }
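The `Mean accuracy: 0.7976 +- 0.0459` line in the k-fold example is consistent with the mean and sample standard deviation (R's `sd()`) of the five fold accuracies. It differs slightly from the pooled overall accuracy (80%) because the folds are not all the same size, so each fold's accuracy carries a different weight in the pooled figure.

```r
# Fold accuracies printed in the k-fold example
fold_acc <- c(0.8333, 0.8333, 0.8095, 0.7895, 0.7222)
round(mean(fold_acc), 4)  # 0.7976
round(sd(fold_acc), 4)    # 0.0459 (sample standard deviation)
```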