| Title: | Higher Level 'API' for 'torch' | 
| Version: | 0.5.0 | 
| Description: | A high level interface for 'torch' providing utilities to reduce the amount of code needed for common tasks, abstract away torch details and make the same code work on both the 'CPU' and 'GPU'. It's flexible enough to support expressing a large range of models. It's heavily inspired by 'fastai' by Howard et al. (2020) <doi:10.48550/arXiv.2002.04688>, 'Keras' by Chollet et al. (2015) and 'PyTorch Lightning' by Falcon et al. (2019) <doi:10.5281/zenodo.3828935>. | 
| License: | MIT + file LICENSE | 
| URL: | https://mlverse.github.io/luz/, https://github.com/mlverse/luz | 
| Encoding: | UTF-8 | 
| RoxygenNote: | 7.3.2 | 
| Imports: | torch (≥ 0.11.9000), magrittr, zeallot, rlang (≥ 1.0.0), coro, glue, progress, R6, generics, purrr, fs, prettyunits, cli | 
| Suggests: | knitr, rmarkdown, testthat (≥ 3.0.0), covr, Metrics, withr, vdiffr, ggplot2 (≥ 3.0.0), dplyr, torchvision, tfevents (≥ 0.0.2), tidyr | 
| VignetteBuilder: | knitr | 
| Config/testthat/edition: | 3 | 
| Collate: | 'accelerator.R' 'as_dataloader.R' 'utils.R' 'callbacks.R' 'callbacks-amp.R' 'callbacks-interrupt.R' 'callbacks-mixup.R' 'callbacks-monitor-metrics.R' 'callbacks-profile.R' 'callbacks-resume.R' 'callbacks-tfevents.R' 'context.R' 'losses.R' 'lr-finder.R' 'metrics.R' 'metrics-auc.R' 'module-plot.R' 'module-print.R' 'module.R' 'reexports.R' 'serialization.R' | 
| BugReports: | https://github.com/mlverse/luz/issues | 
| NeedsCompilation: | no | 
| Packaged: | 2025-07-29 15:53:02 UTC; dfalbel | 
| Author: | Daniel Falbel [aut, cre, cph], Christophe Regouby [ctb], RStudio [cph] | 
| Maintainer: | Daniel Falbel <daniel@rstudio.com> | 
| Repository: | CRAN | 
| Date/Publication: | 2025-07-29 16:30:09 UTC | 
Pipe operator
Description
See magrittr::%>% for details.
Usage
lhs %>% rhs
Create an accelerator
Description
Create an accelerator
Usage
accelerator(
  device_placement = TRUE,
  cpu = FALSE,
  cuda_index = torch::cuda_current_device()
)
Arguments
| device_placement | (logical) whether the  | 
| cpu | (logical) whether the training procedure should run on the CPU. | 
| cuda_index | (integer) index of the CUDA device to use if multiple GPUs are available. Default: the result of torch::cuda_current_device(). | 
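A minimal sketch of forcing training onto the CPU by passing an accelerator to fit(). It assumes the torch backend (libtorch) is installed; the toy tensors and the linear model are illustrative only.

```r
library(torch)
library(luz)

# Toy regression data (illustrative only)
x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

model <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1)

# cpu = TRUE keeps training on the CPU even when a GPU is available
fitted <- model %>% fit(
  list(x, y),
  epochs = 1,
  accelerator = accelerator(cpu = TRUE),
  verbose = FALSE
)
```

Leaving accelerator = NULL in fit() lets luz create a default accelerator() instead.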
Creates a dataloader from its input
Description
as_dataloader is used internally by luz to convert the data and
valid_data arguments passed to fit.luz_module_generator() into a
torch::dataloader.
Usage
as_dataloader(x, ...)
## S3 method for class 'dataset'
as_dataloader(x, ..., batch_size = 32)
## S3 method for class 'iterable_dataset'
as_dataloader(x, ..., batch_size = 32)
## S3 method for class 'list'
as_dataloader(x, ...)
## S3 method for class 'dataloader'
as_dataloader(x, ...)
## S3 method for class 'matrix'
as_dataloader(x, ...)
## S3 method for class 'numeric'
as_dataloader(x, ...)
## S3 method for class 'array'
as_dataloader(x, ...)
## S3 method for class 'torch_tensor'
as_dataloader(x, ...)
Arguments
| x | the input object. | 
| ... | Passed to  | 
| batch_size | (int, optional): how many samples per batch to load
(default:  | 
Details
as_dataloader methods should have sensible defaults for batch_size,
parallel workers, etc.
It allows users to quickly experiment with fit.luz_module_generator() without
having to create a torch::dataset and a torch::dataloader in simple
experiments.
Methods (by class)
-  as_dataloader(dataset): Converts a torch::dataset() to a torch::dataloader().
-  as_dataloader(iterable_dataset): Converts a torch::iterable_dataset() into a torch::dataloader()
-  as_dataloader(list): Converts a list of tensors or arrays with the same size in the first dimension to a torch::dataloader()
-  as_dataloader(dataloader): Returns the same dataloader
-  as_dataloader(matrix): Converts the matrix to a dataloader
-  as_dataloader(numeric): Converts the numeric vector to a dataloader
-  as_dataloader(array): Converts the array to a dataloader
-  as_dataloader(torch_tensor): Converts the tensor to a dataloader
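A minimal sketch of the list method, assuming the torch backend is installed. Extra arguments such as batch_size are forwarded down to the dataset method; the tensors are illustrative.

```r
library(torch)
library(luz)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

# Both tensors share the first dimension (100 samples)
dl <- as_dataloader(list(x, y), batch_size = 25)

# Number of batches in one pass over the data
n_batches <- length(dl)

# Pull the first batch to inspect the input shape
batch <- dataloader_next(dataloader_make_iter(dl))
batch[[1]]$shape
```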
Overriding
You can implement your own as_dataloader S3 method if you want your data
structure to be automatically supported by luz's fit.luz_module_generator().
The method must satisfy the following conditions:
- The method should return a torch::dataloader().
- The only required argument is x. All other arguments should have good defaults.
It's better to avoid implementing as_dataloader methods for common S3 classes
like data.frame. In that case, it's better to assign a different class to
the inputs and implement as_dataloader for it.
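A sketch of such a method for a hypothetical class my_pairs (the class and its constructor new_my_pairs() are made up for this illustration). It only requires x, gives batch_size a default and delegates to the list method, so it returns a torch::dataloader().

```r
library(torch)
library(luz)

# Hypothetical container class, used only for this illustration
new_my_pairs <- function(x, y) {
  structure(list(x = x, y = y), class = "my_pairs")
}

# Only `x` is required; every other argument has a default
as_dataloader.my_pairs <- function(x, ..., batch_size = 16) {
  # Reuse the list method, which builds a tensor dataset
  as_dataloader(list(x$x, x$y), ..., batch_size = batch_size)
}

pairs <- new_my_pairs(torch_randn(50, 3), torch_randn(50, 1))
dl <- as_dataloader(pairs)
length(dl)
```

Defined at the top level of a script, the method is found by S3 dispatch; inside a package it would need to be registered in the NAMESPACE.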
Context object
Description
Context object storing information about the model training context. See also ctx.
Public fields
- buffers
- This is a list of buffers that callbacks can use to write temporary information into ctx.
Active bindings
- records
- stores information about values logged with self$log.
- device
- allows querying the current accelerator device.
- callbacks
- list of callbacks that will be called.
- iter
- current iteration.
- batch
- the current batch data: a list with input data and targets.
- input
- a shortcut for ctx$batch[[1]].
- target
- a shortcut for ctx$batch[[2]].
- min_epochs
- the minimum number of epochs that the model will run for.
- max_epochs
- the maximum number of epochs that the model will run for.
- hparams
- a list of hyperparameters that were used to initialize ctx$model.
- opt_hparams
- a list of hyperparameters used to initialize ctx$optimizers.
- train_data
- a dataloader used for training the model.
- valid_data
- a dataloader used during model validation.
- accelerator
- an accelerator() used to move data, models, etc. to the correct device.
- optimizers
- a named list of optimizers that will be used during model training.
- verbose
- bool indicating whether the process is in verbose mode or not.
- handlers
- List of error handlers that can be used during the training loop. See rlang::try_fetch() for more info.
- epoch_handlers
- List of error handlers that can be used inside the epochs loop. See rlang::try_fetch() for more info.
- training
- A bool indicating if the model is in training or validation mode. 
- model
- The model being trained. 
- pred
- Last predicted values. 
- opt
- Current optimizer. 
- opt_name
- Current optimizer name. 
- data
- Current dataloader in use. 
- loss_fn
- Loss function used to train the model 
- loss
- Last computed loss values. Detached from the graph. 
- loss_grad
- Last computed loss value, not detached, so you can do additional transformations.
- epoch
- Current epoch. 
- metrics
- List of metrics that are tracked by the process. 
- step_opt
- Defines how step is called for the optimizer. It must be a function taking an optimizer as argument. 
Methods
Public methods
Method new()
Initializes the context object with minimal necessary information.
Usage
context$new(verbose, accelerator, callbacks, training)
Arguments
- verbose
- Whether the context should be in verbose mode or not. 
- accelerator
- A luz accelerator() that configures device placement and others.
- callbacks
- A list of callbacks used by the model. See luz_callback().
- training
- A boolean that indicates if the context is in training mode or not. 
Method log()
Allows logging arbitrary information in the ctx.
Usage
context$log(what, set, value, index = NULL, append = TRUE)
Arguments
- what
- (string) What you are logging. 
- set
- (string) Usually 'train' or 'valid' indicating the set you want to log to. But can be arbitrary info. 
- value
- Arbitrary value to log. 
- index
- Index at which this value should be logged. If NULL the value is added to the end of the list, otherwise the index is used.
- append
- If TRUE and a value in the corresponding index already exists, then the value is appended to the current value. If FALSE the current value is overwritten by the new value.
Method log_metric()
Log a metric by its name and value. Metric values are indexed by epoch.
Usage
context$log_metric(name, value)
Arguments
- name
- name of the metric 
- value
- Arbitrary value to log. 
Method get_log()
Get a specific value from the log.
Usage
context$get_log(what, set, index = NULL)
Arguments
- what
- (string) What you are logging. 
- set
- (string) Usually 'train' or 'valid', indicating the set to read from. But can be arbitrary info.
- index
- Index of the logged value to retrieve.
Method get_metrics()
Get all metrics for a given epoch and set.
Usage
context$get_metrics(set, epoch = NULL)
Arguments
- set
- (string) Usually 'train' or 'valid', indicating the set to query. But can be arbitrary info.
- epoch
- The epoch you want to extract metrics from. 
Method get_metric()
Get the value of a metric given its name, epoch and set.
Usage
context$get_metric(name, set, epoch = NULL)
Arguments
- name
- name of the metric 
- set
- (string) Usually 'train' or 'valid', indicating the set to query. But can be arbitrary info.
- epoch
- The epoch you want to extract metrics from. 
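A minimal sketch of reading a metric from the context inside a callback. The callback loss_tracker is hypothetical, and it assumes the training loss is recorded under the metric name "loss" by the time on_epoch_end runs; the data and model are illustrative.

```r
library(torch)
library(luz)

# Hypothetical callback that copies each epoch's training loss
loss_tracker <- luz_callback(
  name = "loss_tracker",
  losses = NULL,
  on_epoch_end = function() {
    # Assumes the loss is stored under the name "loss" for the "train" set
    value <- ctx$get_metric("loss", "train", ctx$epoch)
    self$losses <- c(self$losses, value)
  }
)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

tracker <- loss_tracker()
fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(list(x, y), epochs = 3, callbacks = list(tracker), verbose = FALSE)

tracker$losses
```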
Method get_formatted_metrics()
Get formatted metrics values
Usage
context$get_formatted_metrics(set, epoch = NULL)
Arguments
- set
- (string) Usually 'train' or 'valid', indicating the set to query. But can be arbitrary info.
- epoch
- The epoch you want to extract metrics from. 
Method get_metrics_df()
Get a data.frame containing all metrics.
Usage
context$get_metrics_df()
Method set_verbose()
Allows setting the verbose attribute.
Usage
context$set_verbose(verbose = NULL)
Arguments
- verbose
- boolean. If TRUE verbose mode is used. If FALSE non verbose. If NULL we use the result of interactive().
Method clean()
Removes unnecessary information from the context object.
Usage
context$clean()
Method call_callbacks()
Call the selected callbacks, where name is the type of callback to call,
e.g. 'on_epoch_begin'.
Usage
context$call_callbacks(name)
Arguments
- name
- name of the callback breakpoint to call, e.g. 'on_epoch_begin'.
Method state_dict()
Returns a list containing minimal information from the context. Used to create the returned values.
Usage
context$state_dict()
Method unsafe_set_records()
Are you sure you know what you are doing?
Usage
context$unsafe_set_records(records)
Arguments
- records
- New set of records to be set. 
Method clone()
The objects of this class are cloneable with this method.
Usage
context$clone(deep = FALSE)
Arguments
- deep
- Whether to make a deep clone. 
Context object
Description
Context objects used in luz to share information between model methods, metrics and callbacks.
Details
The ctx object is used in luz to share information between the
training loop and callbacks, model methods, and metrics. The table below
describes information available in the ctx by default. Other callbacks
could potentially modify these attributes or add new ones.
| Attribute | Description | 
| verbose | The value (TRUE or FALSE) attributed to the verbose argument in fit. | 
| accelerator | Accelerator object used to query the correct device to place models, data, etc. It assumes the value passed to the accelerator parameter in fit. | 
| model | Initialized nn_module object that will be trained during the fit procedure. | 
| optimizers | A named list of optimizers used during training. | 
| data | The currently in-use dataloader. When training it’s ctx$train_data, when doing validation it’s ctx$valid_data. It can also be the prediction dataset when in predict. | 
| train_data | Dataloader passed to the data argument in fit. Modified to yield data in the selected device. | 
| valid_data | Dataloader passed to the valid_data argument in fit. Modified to yield data in the selected device. | 
| min_epochs | Minimum number of epochs the model will be trained for. | 
| max_epochs | Maximum number of epochs the model will be trained for. | 
| epoch | Current training epoch. | 
| iter | Current training iteration. It’s reset every epoch and when going from training to validation. | 
| training | Whether the model is in training or validation mode. See also help("luz_callback_train_valid") | 
| callbacks | List of callbacks that will be called during the training procedure. It’s the union of the list passed to the callbacks parameter and the default callbacks. | 
| step | Closure that will be used to do one step of the model. It’s used for both training and validation. Takes no argument, but can access the ctx object. | 
| call_callbacks | Call callbacks by name. For example call_callbacks("on_train_begin") will call all callbacks that provide methods for this point. | 
| batch | Last batch obtained by the dataloader. A batch is a list() with 2 elements, one that is used as input and the other as target. | 
| input | First element of the last batch obtained by the current dataloader. | 
| target | Second element of the last batch obtained by the current dataloader. | 
| pred | Last predictions obtained by ctx$model$forward. Note: can be potentially modified by previously run callbacks. Also note that this might not be available if you used a custom training step. | 
| loss_fn | The active loss function that will be minimized during training. | 
| loss | Last computed loss from the model. Note: this might not be available if you modified the training or validation step. | 
| opt | Current optimizer, i.e. the optimizer that will be used to do the next step to update parameters. | 
| opt_nm | Current optimizer name. By default it’s opt, but can change if your model uses more than one optimizer depending on the set of parameters being optimized. | 
| metrics | list() with current metric objects that are updated at every on_train_batch_end() or on_valid_batch_end(). See also help("luz_callback_metrics") | 
| records | list() recording metric values for training and validation for each epoch. See also help("luz_callback_metrics"). Also records profiling metrics. See help("luz_callback_profile") for more information. | 
| handlers | A named list() of handlers that is passed to rlang::with_handlers() during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. | 
| epoch_handlers | A named list of handlers that is used with rlang::with_handlers(). Those handlers are used inside the epochs loop, thus you can handle epoch specific conditions that won’t necessarily end training. | 
Context attributes
See Also
Context object: context
Evaluates a fitted model on a dataset
Description
Evaluates a fitted model on a dataset
Usage
evaluate(
  object,
  data,
  ...,
  metrics = NULL,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)
Arguments
| object | A fitted model to evaluate. | 
| data | (dataloader, dataset or list) A dataloader created with
 | 
| ... | Currently unused. | 
| metrics | A list of luz metrics to be tracked during evaluation. If  | 
| callbacks | (list, optional) A list of callbacks defined with
 | 
| accelerator | (accelerator, optional) An optional  | 
| verbose | (logical, optional) An optional boolean value indicating if
the fitting procedure should emit output to the console during training.
By default, it will produce output if  | 
| dataloader_options | Options used when creating a dataloader. See
 | 
Details
Once a model has been trained you might want to evaluate its performance
on a different dataset. For that reason, luz provides the ?evaluate
function that takes a fitted model and a dataset and computes the
metrics attached to the model.
Evaluate returns a luz_module_evaluation object that you can query for
metrics using the get_metrics function or simply print to see the
results.
For example:
evaluation <- fitted %>% evaluate(data = valid_dl)
metrics <- get_metrics(evaluation)
print(evaluation)
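A self-contained version of that workflow, assuming the torch backend is installed. The held-out tensors stand in for valid_dl and are passed as a list, which evaluate() accepts as data; all data here is random and illustrative.

```r
library(torch)
library(luz)

x_train <- torch_randn(100, 10)
y_train <- torch_randn(100, 1)
x_test <- torch_randn(40, 10)
y_test <- torch_randn(40, 1)

fitted <- nn_linear %>%
  setup(
    loss = nnf_mse_loss,
    optimizer = optim_adam,
    metrics = list(luz_metric_mae())
  ) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(list(x_train, y_train), epochs = 2, verbose = FALSE)

# Compute the attached metrics on data the model was not trained on
evaluation <- fitted %>% evaluate(data = list(x_test, y_test), verbose = FALSE)
metrics <- get_metrics(evaluation)
metrics
```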
See Also
Other training: 
fit.luz_module_generator(),
predict.luz_module_fitted(),
setup()
Fit a nn_module
Description
Fit a nn_module
Usage
## S3 method for class 'luz_module_generator'
fit(
  object,
  data,
  epochs = 10,
  callbacks = NULL,
  valid_data = NULL,
  accelerator = NULL,
  verbose = NULL,
  ...,
  dataloader_options = NULL
)
Arguments
| object | An  | 
| data | (dataloader, dataset or list) A dataloader created with
 | 
| epochs | (int) The maximum number of epochs for training the model. If a
single value is provided, this is taken to be the  | 
| callbacks | (list, optional) A list of callbacks defined with
 | 
| valid_data | (dataloader, dataset, list or scalar value; optional) A
dataloader created with  | 
| accelerator | (accelerator, optional) An optional  | 
| verbose | (logical, optional) An optional boolean value indicating if
the fitting procedure should emit output to the console during training.
By default, it will produce output if  | 
| ... | Currently unused. | 
| dataloader_options | Options used when creating a dataloader. See
 | 
Value
A fitted object that can be saved with luz_save() and can be
printed with print() and plotted with plot().
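A minimal sketch of fitting with a separate validation set, assuming the torch backend is installed. It assumes the data.frame returned by get_metrics() has a set column distinguishing "train" and "valid" rows; the data is random.

```r
library(torch)
library(luz)

x_train <- torch_randn(100, 10)
y_train <- torch_randn(100, 1)
x_valid <- torch_randn(30, 10)
y_valid <- torch_randn(30, 1)

fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(
    list(x_train, y_train),
    epochs = 3,
    valid_data = list(x_valid, y_valid),
    verbose = FALSE
  )

metrics <- get_metrics(fitted)
unique(metrics$set)
```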
See Also
predict.luz_module_fitted() for how to create predictions.
setup() to find out how to create modules that can be trained with fit.
Other training: 
evaluate(),
predict.luz_module_fitted(),
setup()
Get metrics from the object
Description
Get metrics from the object
Usage
get_metrics(object, ...)
## S3 method for class 'luz_module_fitted'
get_metrics(object, ...)
Arguments
| object | The object to query for metrics. | 
| ... | Currently unused. | 
Value
A data.frame containing the metric values.
Methods (by class)
-  get_metrics(luz_module_fitted): Extract metrics from a luz fitted model.
Learning Rate Finder
Description
Learning Rate Finder
Usage
lr_finder(
  object,
  data,
  steps = 100,
  start_lr = 1e-07,
  end_lr = 0.1,
  log_spaced_intervals = TRUE,
  ...,
  verbose = NULL
)
Arguments
| object | An nn_module that has been setup(). | 
| data | (dataloader) A dataloader created with torch::dataloader() used for learning rate finding. | 
| steps | (integer) The number of steps to iterate over in the learning rate finder. Default: 100. | 
| start_lr | (float) The smallest learning rate. Default: 1e-7. | 
| end_lr | (float) The highest learning rate. Default: 1e-1. | 
| log_spaced_intervals | (logical) Whether to divide the range between start_lr and end_lr into log-spaced intervals (alternative: uniform intervals). Default: TRUE | 
| ... | Other arguments passed to  | 
| verbose | Whether to show a progress bar during the process. | 
Value
A dataframe with two columns: learning rate and loss
Examples
if (torch::torch_is_installed()) {
library(torch)
ds <- torch::tensor_dataset(x = torch_randn(100, 10), y = torch_randn(100, 1))
dl <- torch::dataloader(ds, batch_size = 32)
model <- torch::nn_linear
model <- model %>% setup(
  loss = torch::nn_mse_loss(),
  optimizer = torch::optim_adam
) %>%
  set_hparams(in_features = 10, out_features = 1)
records <- lr_finder(model, dl, verbose = FALSE)
plot(records)
}
Create a new callback
Description
Create a new callback
Usage
luz_callback(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)
Arguments
| name | name of the callback | 
| ... | Public methods of the callback. The name of the methods is used to know how they should be called. See the details section. | 
| private | An optional list of private members, which can be functions and non-functions. | 
| active | An optional list of active binding functions. | 
| parent_env | An environment to use as the parent of newly-created objects. | 
| inherit | A R6ClassGenerator object to inherit from; in other words, a
superclass. This is captured as an unevaluated expression which is
evaluated in  | 
Details
Let’s implement a callback that prints ‘Iteration n’ (where n is the
iteration number) for every batch in the training set and ‘Done’ when an
epoch is finished. For that task we use the luz_callback function:
print_callback <- luz_callback(
  name = "print_callback",
  initialize = function(message) {
    self$message <- message
  },
  on_train_batch_end = function() {
    cat("Iteration ", ctx$iter, "\n")
  },
  on_epoch_end = function() {
    cat(self$message, "\n")
  }
)
luz_callback() takes named functions as ... arguments, where the
name indicates the moment at which the callback should be called. For
instance, on_train_batch_end() is called at the end of every training
batch, and on_epoch_end() is called at the end of every epoch.
The returned value of luz_callback() is a function that initializes an
instance of the callback. Callbacks can have initialization parameters,
like the name of a file where you want to log the results. In that case,
you can pass an initialize method when creating the callback
definition, and save these parameters to the self object. In the above
example, the callback has a message parameter that is printed at the
end of each epoch.
Once a callback is defined it can be passed to the fit function via
the callbacks parameter:
fitted <- net %>%
  setup(...) %>%
  fit(..., callbacks = list(
    print_callback(message = "Done!")
  ))
Callbacks can be called in many different positions of the training loop, including combinations of them. Here’s an overview of possible callback breakpoints:
Start Fit
   - on_fit_begin
  Start Epoch Loop
     - on_epoch_begin
    Start Train
       - on_train_begin
      Start Batch Loop
         - on_train_batch_begin
          Start Default Training Step
            - on_train_batch_after_pred
            - on_train_batch_after_loss
            - on_train_batch_before_backward
            - on_train_batch_before_step
            - on_train_batch_after_step
          End Default Training Step
         - on_train_batch_end
      End Batch Loop
       - on_train_end
    End Train
    Start Valid
       - on_valid_begin
      Start Batch Loop
         - on_valid_batch_begin
          Start Default Validation Step
            - on_valid_batch_after_pred
            - on_valid_batch_after_loss
          End Default Validation Step
         - on_valid_batch_end
      End Batch Loop
       - on_valid_end
    End Valid
     - on_epoch_end
  End Epoch Loop
   - on_fit_end
End Fit
Every step marked with on_* is a point in the training procedure that
is available for callbacks to be called.
The other important part of callbacks is the ctx (context) object. See
help("ctx") for details.
By default, callbacks are called in the same order as they were passed
to fit (or predict or evaluate), but you can provide a weight
attribute that will control the order in which it will be called. For
example, if one callback has weight = 10 and another has weight = 1,
then the first one is called after the second one. Callbacks that don’t
specify a weight attribute are considered weight = 0. A few built-in
callbacks in luz already provide a weight value. For example, the
?luz_callback_early_stopping has a weight of Inf, since in general
we want to run it as the last thing in the loop.
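A small sketch of weight-based ordering. The two callbacks are hypothetical and simply record, in a global vector, when their on_fit_begin runs; they are passed in the "wrong" order to show that the lower weight runs first. It assumes the torch backend is installed.

```r
library(torch)
library(luz)

call_order <- character()

# Hypothetical callbacks that only record when they are called
cb_heavy <- luz_callback(
  name = "cb_heavy",
  weight = 10,
  on_fit_begin = function() call_order <<- c(call_order, "heavy")
)
cb_light <- luz_callback(
  name = "cb_light",
  weight = 1,
  on_fit_begin = function() call_order <<- c(call_order, "light")
)

x <- torch_randn(20, 2)
y <- torch_randn(20, 1)

fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 2, out_features = 1) %>%
  fit(
    list(x, y),
    epochs = 1,
    callbacks = list(cb_heavy(), cb_light()),
    verbose = FALSE
  )

call_order
```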
Value
A luz_callback that can be passed to fit.luz_module_generator().
Prediction callbacks
You can also use callbacks when using predict(). In this case the supported
callback methods are detailed below:
Start predict
   - on_predict_begin
  Start prediction loop
     - on_predict_batch_begin
     - on_predict_batch_end
  End prediction loop
   - on_predict_end
End predict
Evaluate callbacks
Callbacks can also be used with evaluate(), in this case, the callbacks that
are used are equivalent to those of the validation loop when using fit():
Start Valid
   - on_valid_begin
  Start Batch Loop
     - on_valid_batch_begin
    Start Default Validation Step
      - on_valid_batch_after_pred
      - on_valid_batch_after_loss
    End Default Validation Step
     - on_valid_batch_end
  End Batch Loop
   - on_valid_end
End Valid
See Also
Other luz_callbacks: 
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
print_callback <- luz_callback(
 name = "print_callback",
 on_train_batch_end = function() {
   cat("Iteration ", ctx$iter, "\n")
 },
 on_epoch_end = function() {
   cat("Done!\n")
 }
)
Resume training callback
Description
This callback allows you to resume training a model.
Usage
luz_callback_auto_resume(path = "./state.pt")
Arguments
| path | Path to save state files for the model. | 
Details
When using it, the model weights and optimizer state are serialized at the end of each epoch. If something fails during training, simply re-running the same script will restart training from the epoch right after the last one that was serialized.
Customizing serialization
By default, the model, the optimizer state and the records are serialized. Callbacks can
be used to customize serialization by implementing the state_dict() and
load_state_dict() methods.
If those methods are implemented, then state_dict() is called at the end of
each epoch and load_state_dict() is called when the model is resumed.
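A minimal sketch of a callback that takes part in serialization. The callback epoch_counter and its n_epochs field are hypothetical, and the shape of the state (a named list) is an assumption; the example calls load_state_dict() by hand to mimic what happens on resume.

```r
library(luz)

# Hypothetical callback whose counter survives a resume
epoch_counter <- luz_callback(
  name = "epoch_counter",
  n_epochs = 0,
  on_epoch_end = function() {
    self$n_epochs <- self$n_epochs + 1
  },
  # Called at the end of each epoch to collect state to serialize
  state_dict = function() {
    list(n_epochs = self$n_epochs)
  },
  # Called when the model is resumed, with the previously saved state
  load_state_dict = function(state_dict) {
    self$n_epochs <- state_dict$n_epochs
  }
)

cb <- epoch_counter()
cb$load_state_dict(list(n_epochs = 4))
cb$state_dict()$n_epochs
```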
Note
In general you will want to add this callback as the last in the callbacks
list. This way, the serialized state is likely to contain all possible changes
that other callbacks could have made at 'on_epoch_end'. The default weight
attribute of this callback is Inf.
Read the checkpointing article in the pkgdown website for more information.
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
if (torch::torch_is_installed()) {
library(torch)
library(luz)
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
  setup(optimizer = optim_sgd, loss = nnf_mse_loss) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 0.01)
# simulate a failure at the end of epoch 5, happening only once.
callback_stop <- luz_callback(
  "interrupt",
  failed = FALSE,
  on_epoch_end = function() {
    if (ctx$epoch == 5 && !self$failed) {
      self$failed <- TRUE
      stop("Error on epoch 5")
    }
  }
)
path <- tempfile()
autoresume <- luz_callback_auto_resume(path = path)
interrupt <- callback_stop()
# try once and the model fails
try({
  results <- model %>% fit(
    list(x, y),
    callbacks = list(autoresume, interrupt),
    verbose = FALSE
  )
})
# model resumes and completes
results <- model %>% fit(
  list(x, y),
  callbacks = list(autoresume, interrupt),
  verbose = FALSE
)
get_metrics(results)
}
CSV logger callback
Description
Logs metrics obtained during training to a file on disk. The file will have 1 line for each epoch/validation.
Usage
luz_callback_csv_logger(path)
Arguments
| path | path to a file on disk. | 
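A minimal sketch of logging to a temporary CSV file and reading it back, assuming the torch backend is installed. The exact columns written are not shown here, so the example only inspects the file after training.

```r
library(torch)
library(luz)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

log_path <- tempfile(fileext = ".csv")

fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(
    list(x, y),
    epochs = 2,
    callbacks = list(luz_callback_csv_logger(log_path)),
    verbose = FALSE
  )

training_log <- read.csv(log_path)
training_log
```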
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Early stopping callback
Description
Stops training when a monitored metric stops improving
Usage
luz_callback_early_stopping(
  monitor = "valid_loss",
  min_delta = 0,
  patience = 0,
  mode = "min",
  baseline = NULL
)
Arguments
| monitor | A string in the format  | 
| min_delta | Minimum improvement to reset the patience counter. | 
| patience | Number of epochs without improvement before stopping training. | 
| mode | Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better). | 
| baseline | An initial value that will be used as the best seen value
in the beginning. Model will stop training if no better than baseline value
is found in the first  | 
Value
A luz_callback that does early stopping.
Note
This callback adds an on_early_stopping breakpoint that can be used to
call callbacks as soon as the model stops training.
If verbose=TRUE in fit.luz_module_generator() a message is printed when
early stopping.
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
cb <- luz_callback_early_stopping()
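A sketch of early stopping inside fit(), assuming the torch backend is installed. Monitoring "valid_loss" requires valid_data; with random data the run may stop well before the 50-epoch maximum, so the example only checks how many epochs were recorded.

```r
library(torch)
library(luz)

x_train <- torch_randn(100, 10)
y_train <- torch_randn(100, 1)
x_valid <- torch_randn(30, 10)
y_valid <- torch_randn(30, 1)

fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(
    list(x_train, y_train),
    epochs = 50,
    valid_data = list(x_valid, y_valid),
    # Stop after 2 epochs without improvement in the validation loss
    callbacks = list(
      luz_callback_early_stopping(monitor = "valid_loss", patience = 2)
    ),
    verbose = FALSE
  )

epochs_run <- max(get_metrics(fitted)$epoch)
epochs_run
```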
Gradient clipping callback
Description
By adding the GradientClip callback, the gradient norm (of type norm_type, default: 2)
is clipped to at most max_norm (default: 1) using torch::nn_utils_clip_grad_norm_(),
which can help avoid loss divergence.
Usage
luz_callback_gradient_clip(max_norm = 1, norm_type = 2)
Arguments
| max_norm | (float or int): max norm of the gradients | 
| norm_type | (float or int): type of the used p-norm. Can be  | 
References
See FastAI documentation for the GradientClip callback.
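A minimal sketch of enabling gradient clipping during fit(), assuming the torch backend is installed; the data and model are illustrative.

```r
library(torch)
library(luz)

x <- torch_randn(100, 10)
y <- torch_randn(100, 1)

fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(
    list(x, y),
    epochs = 2,
    # Clip the L2 norm of the gradients to at most 1 before each step
    callbacks = list(luz_callback_gradient_clip(max_norm = 1, norm_type = 2)),
    verbose = FALSE
  )
```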
Interrupt callback
Description
Adds a handler that allows interrupting the training loop using ctrl + C.
Also registers an on_interrupt breakpoint so users can register callbacks to
be run on training loop interruption.
Usage
luz_callback_interrupt()
Value
A luz_callback
Note
In general you don't need to use this callback yourself because it's always
included by default in fit.luz_module_generator().
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
interrupt_callback <- luz_callback_interrupt()
Keep the best model
Description
Each epoch, if there's improvement in the monitored metric, we serialize the model weights to a temp file. When training is done, we reload the weights from the best model.
Usage
luz_callback_keep_best_model(
  monitor = "valid_loss",
  mode = "min",
  min_delta = 0
)
Arguments
| monitor | A string in the format  | 
| mode | Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better). | 
| min_delta | Minimum improvement for a new value to be considered better than the current best. | 
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
cb <- luz_callback_keep_best_model()
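A sketch of restoring the best weights at the end of training, assuming the torch backend is installed. Monitoring "valid_loss" requires valid_data; the data is random and illustrative.

```r
library(torch)
library(luz)

x_train <- torch_randn(100, 10)
y_train <- torch_randn(100, 1)
x_valid <- torch_randn(30, 10)
y_valid <- torch_randn(30, 1)

fitted <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  fit(
    list(x_train, y_train),
    epochs = 5,
    valid_data = list(x_valid, y_valid),
    # Reload the weights from the epoch with the lowest validation loss
    callbacks = list(
      luz_callback_keep_best_model(monitor = "valid_loss", mode = "min")
    ),
    verbose = FALSE
  )
```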
Learning rate scheduler callback
Description
Initializes and runs torch::lr_scheduler()s.
Usage
luz_callback_lr_scheduler(
  lr_scheduler,
  ...,
  call_on = "on_epoch_end",
  opt_name = NULL
)
Arguments
| lr_scheduler | A  | 
| ... | Additional arguments passed to  | 
| call_on | The callback breakpoint that  | 
| opt_name | name of the optimizer that will be affected by this callback.
Should match the name given in  | 
Value
A luz_callback() generator.
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
if (torch::torch_is_installed()) {
cb <- luz_callback_lr_scheduler(torch::lr_step, step_size = 30)
}
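As a rough illustration of what the scheduler in the example above does, the sketch below computes the learning rate produced by step decay. It assumes that torch::lr_step() multiplies the rate by gamma (0.1 by default) every step_size epochs and that step() runs once per epoch, as with call_on = "on_epoch_end". The helper step_decay_lr() is hypothetical and written in base R.

```r
# Hypothetical helper: learning rate after `epoch` completed epochs under step
# decay, i.e. the rate is multiplied by `gamma` every `step_size` epochs.
step_decay_lr <- function(initial_lr, epoch, step_size, gamma = 0.1) {
  initial_lr * gamma ^ (epoch %/% step_size)
}

epochs <- c(0, 29, 30, 59, 60)
lrs <- sapply(epochs, function(e) step_decay_lr(1e-2, e, step_size = 30))
lrs
```

With step_size = 30, the rate stays at its initial value for the first 30 epochs and then drops by a factor of 10 at epochs 30 and 60.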
Metrics callback
Description
Tracks metrics passed to setup() during training and validation.
Usage
luz_callback_metrics()
Details
This callback takes care of 2 ctx attributes:
-  ctx$metrics: stores the current metric objects, which are initialized once per epoch and then update()d and compute()d at every batch. You will rarely need to work with these metrics directly.
-  ctx$records$metrics: stores metrics per training/validation and epoch. The structure is very similar to ctx$losses.
Value
A luz_callback
Note
In general you won't need to explicitly use the metrics callback as it's
used by default in fit.luz_module_generator().
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Automatic Mixed Precision callback
Description
This callback enables torch::local_autocast() during the model's forward pass
and during loss computation. It then disables autocast and scales the loss
before calling backward() and opt$step(). See the automatic mixed precision
documentation for more information.
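Loss scaling exists because small float16 gradients can underflow to zero: multiplying the loss by a large factor before backward() keeps them representable, and the gradients are divided by the same factor before the optimizer step. The base-R sketch below shows the dynamic scale update used by PyTorch-style gradient scalers: shrink the scale when non-finite gradients appear and grow it after a run of finite steps. The helper update_scale() and its default values are assumptions for illustration, not luz's code.

```r
# Hypothetical helper: one update of a dynamic loss scale.
update_scale <- function(scale, found_inf, good_steps,
                         growth_factor = 2, backoff_factor = 0.5,
                         growth_interval = 2000) {
  if (found_inf) {
    # Non-finite gradients: shrink the scale and reset the counter
    # (the optimizer step would be skipped for this batch).
    return(list(scale = scale * backoff_factor, good_steps = 0))
  }
  good_steps <- good_steps + 1
  if (good_steps >= growth_interval) {
    # A long enough run of finite gradients: try a larger scale.
    return(list(scale = scale * growth_factor, good_steps = 0))
  }
  list(scale = scale, good_steps = good_steps)
}

state <- list(scale = 2^16, good_steps = 0)
state <- update_scale(state$scale, found_inf = TRUE, state$good_steps)
state$scale  # 32768 after backing off
for (i in 1:3) {
  state <- update_scale(state$scale, FALSE, state$good_steps, growth_interval = 3)
}
state$scale  # back to 65536 after three finite steps
```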
Usage
luz_callback_mixed_precision(...)
Arguments
| ... | Passed to torch::cuda_amp_grad_scaler(). | 
Value
A luz_callback
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Mixup callback
Description
Implementation of 'mixup: Beyond Empirical Risk Minimization'.
So far, it has only been tested with categorical data,
where targets are expected to be integers, not one-hot encoded vectors.
This callback is meant to be used together with nn_mixup_loss().
Usage
luz_callback_mixup(alpha = 0.4, ..., run_valid = FALSE, auto_loss = FALSE)
Arguments
| alpha | parameter for the beta distribution used to sample mixing coefficients | 
| ... | currently unused. Just to force named arguments. | 
| run_valid | Whether mixup should also be applied during validation. | 
| auto_loss | Whether to automatically modify the loss function. If TRUE, the loss
function is wrapped to create the mixup loss; make sure that your loss function does not apply reductions. | 
Details
Overall, we follow the fastai implementation of mixup. Namely:
- We work with a single dataloader only, randomly mixing two observations from the same batch.
- We linearly combine losses computed for both targets: loss(output, new_target) = weight * loss(output, target1) + (1 - weight) * loss(output, target2)
- We draw different mixing coefficients for every pair.
- We replace weight with weight = max(weight, 1 - weight) to avoid duplicates.
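The steps above can be sketched in base R for a matrix of inputs. This is a simplified illustration, not luz's implementation: mixup_batch() and mixup_loss() are hypothetical helpers, and the toy loss is assumed to return one value per observation so that the weighting can be applied before averaging.

```r
# Hypothetical helper: mix each observation with another one from the same batch.
mixup_batch <- function(x, y, alpha = 0.4) {
  n <- nrow(x)
  weight <- rbeta(n, alpha, alpha)     # a different coefficient for every pair
  weight <- pmax(weight, 1 - weight)   # weight = max(weight, 1 - weight)
  perm <- sample.int(n)                # partner observations from the same batch
  # `weight` has one entry per row, so R's recycling scales row i by weight[i].
  x_new <- weight * x + (1 - weight) * x[perm, , drop = FALSE]
  list(x = x_new, y1 = y, y2 = y[perm], weight = weight)
}

# Hypothetical helper: weighted combination of per-observation losses, then the mean.
mixup_loss <- function(loss_fn, output, mixed) {
  l1 <- loss_fn(output, mixed$y1)
  l2 <- loss_fn(output, mixed$y2)
  mean(mixed$weight * l1 + (1 - mixed$weight) * l2)
}

set.seed(42)
x <- matrix(rnorm(20), nrow = 5)
y <- c(1, 2, 1, 3, 2)
mixed <- mixup_batch(x, y)
squared_error <- function(output, target) (output - target)^2  # unreduced toy loss
mixup_loss(squared_error, output = rep(2, 5), mixed)
```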
Value
A luz_callback
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
if (torch::torch_is_installed()) {
mixup_callback <- luz_callback_mixup()
}
Checkpoints model weights
Description
This saves checkpoints of the model according to the specified metric and behavior.
Usage
luz_callback_model_checkpoint(
  path,
  monitor = "valid_loss",
  save_best_only = FALSE,
  mode = "min",
  min_delta = 0
)
Arguments
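| path | Path to save the model on disk. The path is interpolated, so it can include the current epoch and the monitored value, e.g. "epoch-{epoch:02d}/model-{monitor:.2f}.pt" (see Examples). | 
| monitor | A string in the format <set>_<metric>, where <set> is either 'train' or 'valid' and <metric> is the abbreviation of a metric tracked during training, for example the default 'valid_loss'. | 
| save_best_only | if TRUE, a model is only saved when it improves on the monitored metric. | 
| mode | Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better). | 
| min_delta | Minimum difference to consider as improvement. Only used when
save_best_only = TRUE. | 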
Note
mode and min_delta are only used when save_best_only = TRUE.
With save_best_only = TRUE, saved models are overwritten if the path parameter
doesn't differentiate between epochs.
Read the checkpointing article in the pkgdown website for more information.
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
luz_callback_model_checkpoint(path= "path/to/dir")
luz_callback_model_checkpoint(path= "path/to/dir/epoch-{epoch:02d}/model.pt")
luz_callback_model_checkpoint(path= "path/to/dir/epoch-{epoch:02d}/model-{monitor:.2f}.pt")
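The examples above rely on placeholders such as {epoch:02d} and {monitor:.2f}, which combine a value name with a sprintf()-style format. The base-R sketch below shows how such a template could be expanded; interpolate_path() is a hypothetical helper for illustration only and is not how luz implements the interpolation.

```r
# Hypothetical helper: expand "{name}" or "{name:fmt}" placeholders in a path,
# formatting values with sprintf() when a format is given.
interpolate_path <- function(template, values) {
  pattern <- "\\{([A-Za-z_]+)(:([^}]*))?\\}"
  placeholders <- regmatches(template, gregexpr(pattern, template))[[1]]
  for (ph in placeholders) {
    inner <- substr(ph, 2, nchar(ph) - 1)          # strip the braces
    parts <- strsplit(inner, ":", fixed = TRUE)[[1]]
    value <- values[[parts[1]]]
    replacement <- if (length(parts) == 2) {
      sprintf(paste0("%", parts[2]), value)        # e.g. "%02d" or "%.2f"
    } else {
      as.character(value)
    }
    template <- sub(ph, replacement, template, fixed = TRUE)
  }
  template
}

interpolate_path("path/to/dir/epoch-{epoch:02d}/model-{monitor:.2f}.pt",
                 list(epoch = 3, monitor = 0.12345))
```

Because every epoch produces a different file name, this style of template avoids overwriting earlier checkpoints.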
Profile callback
Description
Computes the times for high-level operations in the training loops.
Usage
luz_callback_profile()
Details
Records are saved in ctx$records$profile. Times are stored in seconds.
Data is stored in the following structure:
-  fit: time for the entire fit procedure.
-  epoch: times per epoch.
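The record layout described above can be illustrated with a base-R sketch that times a stand-in training loop with Sys.time(). The loop body and the profile list are hypothetical; they only mirror the fit and epoch entries, stored in seconds.

```r
# Illustrative layout only: one total fit time plus one time per epoch, in seconds.
profile <- list(fit = NA_real_, epoch = numeric())
fit_start <- Sys.time()
for (epoch in 1:3) {
  epoch_start <- Sys.time()
  Sys.sleep(0.01)  # stands in for one epoch of training and validation
  elapsed <- as.numeric(difftime(Sys.time(), epoch_start, units = "secs"))
  profile$epoch <- c(profile$epoch, elapsed)
}
profile$fit <- as.numeric(difftime(Sys.time(), fit_start, units = "secs"))
str(profile)
```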
Value
A luz_callback
Note
In general you don't need to use this callback yourself because it's always
included by default in fit.luz_module_generator().
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Examples
profile_callback <- luz_callback_profile()
Progress callback
Description
Responsible for printing progress during training.
Usage
luz_callback_progress()
Value
A luz_callback
Note
In general you don't need to use this callback yourself because it's always
included by default in fit.luz_module_generator().
Printing can be disabled by passing verbose=FALSE to fit.luz_module_generator().
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_resume_from_checkpoint(),
luz_callback_train_valid()
Allow resuming model training from a specific checkpoint
Description
Allow resuming model training from a specific checkpoint
Usage
luz_callback_resume_from_checkpoint(
  path,
  ...,
  restore_model_state = TRUE,
  restore_records = FALSE,
  restore_optimizer_state = FALSE,
  restore_callbacks_state = FALSE
)
Arguments
| path | Path to the checkpoint that you want to resume. | 
| ... | currently unused. | 
| restore_model_state | Whether to restore the model state from the checkpoint. | 
| restore_records | Whether to restore records from the checkpoint. | 
| restore_optimizer_state | Whether to restore the optimizer state from the checkpoint. | 
| restore_callbacks_state | Whether to restore the callbacks state from the checkpoint. | 
Note
Read the checkpointing article in the pkgdown website for more information.
See Also
luz_callback_model_checkpoint()
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_train_valid()
tfevents callback
Description
Logs metrics and other model information in the tfevents file format. Assuming TensorBoard is installed, the results can be visualized with the command shown in Details.
Usage
luz_callback_tfevents(logdir = "logs", histograms = FALSE, ...)
Arguments
| logdir | A directory where the logs will be written. | 
| histograms | A boolean specifying if histograms of model weights should
be logged. It can also be a character vector specifying the names of the parameters
that should be logged (names are the same as names(model$parameters)). | 
| ... | Currently not used. For future expansion. | 
Details
tensorboard --logdir=logs
Examples
if (torch::torch_is_installed()) {
library(torch)
x <- torch_randn(1000, 10)
y <- torch_randn(1000, 1)
model <- nn_linear %>%
  setup(loss = nnf_mse_loss, optimizer = optim_adam) %>%
  set_hparams(in_features = 10, out_features = 1) %>%
  set_opt_hparams(lr = 1e-4)
tmp <- tempfile()
model %>% fit(list(x, y), valid_data = 0.2, callbacks = list(
  luz_callback_tfevents(tmp, histograms = TRUE)
))
}
Train-eval callback
Description
Switches important flags for training and evaluation modes.
Usage
luz_callback_train_valid()
Details
It takes care of three ctx attributes:
-  ctx$model: responsible for calling ctx$model$train() and ctx$model$eval(), when appropriate.
-  ctx$training: sets this flag to TRUE when training and FALSE when in validation mode.
-  ctx$loss: resets the loss attribute to list() when training or validation is finished.
Value
A luz_callback
Note
In general you won't need to explicitly use the train_valid callback as it's
used by default in fit.luz_module_generator().
See Also
Other luz_callbacks: 
luz_callback(),
luz_callback_auto_resume(),
luz_callback_csv_logger(),
luz_callback_early_stopping(),
luz_callback_interrupt(),
luz_callback_keep_best_model(),
luz_callback_lr_scheduler(),
luz_callback_metrics(),
luz_callback_mixed_precision(),
luz_callback_mixup(),
luz_callback_model_checkpoint(),
luz_callback_profile(),
luz_callback_progress(),
luz_callback_resume_from_checkpoint()
Load trained model
Description
Loads a fitted model. See documentation in luz_save().
Usage
luz_load(path)
Arguments
| path | path in file system to the object. | 
See Also
Other luz_save: 
luz_save()
Loads a checkpoint
Description
Works with checkpoints typically created with luz_callback_model_checkpoint().
Usage
luz_load_checkpoint(obj, path, ...)
Arguments
| obj | Object to which we want to load the checkpoint. | 
| path | Path of the checkpoint on disk. | 
| ... | Currently unused; reserved for future extensions. | 
Loads model weights into a fitted object.
Description
This can be useful when you have saved model checkpoints during training and want to reload the best checkpoint at the end.
Usage
luz_load_model_weights(obj, path, ...)
luz_save_model_weights(obj, path)
Arguments
| obj | luz object to which you want to copy the new weights. | 
| path | path to saved model in disk. | 
| ... | other arguments passed to  | 
Value
Returns NULL invisibly.
Warning
luz_load_model_weights operates in place, i.e., it modifies the model object to contain the
new weights.
Creates a new luz metric
Description
Creates a new luz metric
Usage
luz_metric(
  name = NULL,
  ...,
  private = NULL,
  active = NULL,
  parent_env = parent.frame(),
  inherit = NULL
)
Arguments
| name | string naming the new metric. | 
| ... | named list of public methods. You should implement at least
initialize, update and compute (see Details). | 
| private | An optional list of private members, which can be functions and non-functions. | 
| active | An optional list of active binding functions. | 
| parent_env | An environment to use as the parent of newly-created objects. | 
| inherit | An R6ClassGenerator object to inherit from; in other words, a
superclass. This is captured as an unevaluated expression which is
evaluated in parent_env each time an object is instantiated. | 
Details
In order to implement a new luz_metric we need to implement 3 methods:
-  initialize: defines the metric initial state. This function is called for each epoch for both training and validation loops.
-  update: updates the metric internal state. This function is called at every training and validation step with the predictions obtained by the model and the target values obtained from the dataloader.
-  compute: uses the internal state to compute metric values. This function is called whenever we need to obtain the current metric value. E.g., it’s called at every training step for metrics displayed in the progress bar, but only once per epoch to record its value when the progress bar is not displayed.
Optionally, you can implement an abbrev field that gives the metric an
abbreviation that will be used when displaying metric information in the
console or tracking record. If no abbrev is passed, the class name
will be used.
Let’s take a look at the implementation of luz_metric_accuracy so you
can see how to implement a new one:
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or 
  # when printing progress
  abbrev = "Acc", 
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)
Note: It’s good practice for compute() to return regular R values instead
of torch tensors, since other parts of luz expect that.
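To see why the work is split across three methods, the base-R sketch below mimics the lifecycle with a plain environment instead of an R6 class: initialize resets the state, update accumulates it for each batch, and compute turns it into a regular R number. Accumulating sums and counts, rather than averaging per-batch values, keeps the result exact when batches have different sizes. new_mae_metric() is hypothetical and only stands in for a class created with luz_metric().

```r
# Hypothetical stand-in for a luz metric: an environment with the three methods.
new_mae_metric <- function() {
  self <- new.env()
  self$initialize <- function() {
    self$sum_abs_error <- 0
    self$n <- 0
  }
  self$update <- function(preds, target) {
    self$sum_abs_error <- self$sum_abs_error + sum(abs(preds - target))
    self$n <- self$n + length(target)
  }
  self$compute <- function() self$sum_abs_error / self$n  # a plain R number
  self$initialize()
  self
}

preds  <- c(1, 2, 3, 4, 5)
target <- c(1, 1, 3, 3, 5)
metric <- new_mae_metric()
metric$update(preds[1:3], target[1:3])  # batch of 3 observations
metric$update(preds[4:5], target[4:5])  # batch of 2 observations
metric$compute()
```

Here compute() returns 0.4, the mean over all five observations, whereas averaging the two batch means (1/3 and 1/2) would give about 0.417.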
Value
Returns new luz metric.
See Also
Other luz_metrics: 
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
luz_metric_accuracy <- luz_metric(
  # An abbreviation to be shown in progress bars, or
  # when printing progress
  abbrev = "Acc",
  # Initial setup for the metric. Metrics are initialized
  # every epoch, for both training and validation
  initialize = function() {
    self$correct <- 0
    self$total <- 0
  },
  # Run at every training or validation step and updates
  # the internal state. The update function takes `preds`
  # and `target` as parameters.
  update = function(preds, target) {
    pred <- torch::torch_argmax(preds, dim = 2)
    self$correct <- self$correct + (pred == target)$
      to(dtype = torch::torch_float())$
      sum()$
      item()
    self$total <- self$total + pred$numel()
  },
  # Use the internal state to query the metric value
  compute = function() {
    self$correct/self$total
  }
)
Accuracy
Description
Computes accuracy for multi-class classification problems.
Usage
luz_metric_accuracy()
Details
This metric expects logits or probabilities at every update. For each observation (row), it takes the argmax over the class dimension (dim = 2) and compares it to the target.
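In base R, the same comparison can be sketched with max.col(), which returns the column index of the largest value in each row; for a matrix with one row per observation and one column per class, that is the argmax over dim = 2. This illustrates the computation only and is not luz's code.

```r
# One row per observation, one column per class (made-up probabilities).
probs <- rbind(
  c(0.1, 0.7, 0.2),
  c(0.8, 0.1, 0.1),
  c(0.3, 0.3, 0.4),
  c(0.2, 0.5, 0.3)
)
target <- c(2L, 1L, 1L, 2L)                    # 1-based class indices
pred <- max.col(probs, ties.method = "first")  # row-wise argmax
accuracy <- mean(pred == target)
accuracy
```

Three of the four predictions match the targets, so the accuracy is 0.75.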
Value
Returns new luz metric.
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_accuracy()
metric <- metric$new()
metric$update(torch_randn(100, 10), torch::torch_randint(1, 10, size = 100))
metric$compute()
}
Binary accuracy
Description
Computes the accuracy for binary classification problems where the
model returns probabilities. Commonly used when the loss is torch::nn_bce_loss().
Usage
luz_metric_binary_accuracy(threshold = 0.5)
Arguments
| threshold | value used to classify observations as 0 or 1. | 
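A base-R sketch of the computation, assuming an observation is classified as 1 when its probability is at least threshold (whether luz uses >= or > at the boundary is an assumption here). binary_accuracy() is a hypothetical helper.

```r
# Hypothetical helper: threshold probabilities, then compare with 0/1 targets.
binary_accuracy <- function(probs, target, threshold = 0.5) {
  pred <- as.integer(probs >= threshold)  # classify as 0 or 1
  mean(pred == target)
}

probs  <- c(0.9, 0.2, 0.6, 0.4, 0.55)
target <- c(1, 0, 1, 0, 1)
binary_accuracy(probs, target)                   # all five correct
binary_accuracy(probs, target, threshold = 0.7)  # 0.6 and 0.55 become 0
```

Raising the threshold turns the two moderate positives into predicted zeros, so accuracy drops from 1 to 0.6.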
Value
Returns new luz metric.
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_binary_accuracy(threshold = 0.5)
metric <- metric$new()
metric$update(torch_rand(100), torch::torch_randint(0, 2, size = 100))
metric$compute()
}
Binary accuracy with logits
Description
Computes accuracy for binary classification problems where the model
returns logits. Commonly used together with torch::nn_bce_with_logits_loss().
Usage
luz_metric_binary_accuracy_with_logits(threshold = 0.5)
Arguments
| threshold | value used to classify observations as 0 or 1. | 
Details
Probabilities are generated using torch::nnf_sigmoid() and threshold is used to
classify observations as 0 or 1.
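Because the sigmoid is monotonic, thresholding sigmoid(logit) at a probability t gives the same classes as thresholding the logit itself at log(t / (1 - t)); for t = 0.5 that cut-off is 0. The base-R check below uses plogis() (the logistic function) and qlogis() (its inverse) on made-up logits; it illustrates the math rather than luz's code.

```r
logits <- c(-2.3, -0.1, 0, 0.4, 3.1)
probs <- plogis(logits)                # sigmoid: 1 / (1 + exp(-x))
via_probs  <- as.integer(probs >= 0.5)
via_logits <- as.integer(logits >= 0)  # equivalent cut-off on the logit scale

# A general probability threshold t maps to qlogis(t) on the logit scale.
t <- 0.8
identical(as.integer(probs >= t), as.integer(logits >= qlogis(t)))
```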
Value
Returns new luz metric.
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_binary_accuracy_with_logits(threshold = 0.5)
metric <- metric$new()
metric$update(torch_randn(100), torch::torch_randint(0, 2, size = 100))
metric$compute()
}
Computes the area under the ROC
Description
To avoid storing all predictions and targets for an epoch we compute confusion matrices across a range of pre-established thresholds.
Usage
luz_metric_binary_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE
)
Arguments
| num_thresholds | Number of thresholds used to compute confusion matrices when thresholds is not given.
In that case, thresholds are created by taking num_thresholds values evenly spaced in the unit interval. | 
| thresholds | (optional) If thresholds are passed, then those are used to compute the
confusion matrices and num_thresholds is ignored. | 
| from_logits | Boolean indicating if predictions are logits, in that case we use sigmoid to put them in the unit interval. | 
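The thresholded strategy can be sketched in base R: each batch only adds to per-threshold true-positive and false-positive counts, and the area is computed at the end with the trapezoidal rule over the resulting ROC points. The helpers are hypothetical, the rule preds >= threshold is an assumption, and the curve only reaches (1, 1) if the smallest threshold is no larger than every prediction. With the data from this section's example, the thresholds coincide with the predictions, so the result equals the exact AUC of 8/9.

```r
# Hypothetical helpers: per-threshold counts accumulated across batches.
new_auroc_state <- function(thresholds) {
  thresholds <- sort(thresholds, decreasing = TRUE)
  list(thresholds = thresholds,
       tp = numeric(length(thresholds)), fp = numeric(length(thresholds)),
       pos = 0, neg = 0)
}

update_auroc <- function(state, preds, target) {
  for (i in seq_along(state$thresholds)) {
    predicted_pos <- preds >= state$thresholds[i]
    state$tp[i] <- state$tp[i] + sum(predicted_pos & target == 1)
    state$fp[i] <- state$fp[i] + sum(predicted_pos & target == 0)
  }
  state$pos <- state$pos + sum(target == 1)
  state$neg <- state$neg + sum(target == 0)
  state
}

compute_auroc <- function(state) {
  tpr <- c(0, state$tp / state$pos)  # prepend the (0, 0) corner of the curve
  fpr <- c(0, state$fp / state$neg)
  sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)  # trapezoidal rule
}

actual    <- c(1, 1, 1, 0, 0, 0)
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
state <- new_auroc_state(thresholds = predicted)
for (idx in list(1:2, 3:4, 5:6)) {  # three batches, as in the example
  state <- update_auroc(state, predicted[idx], actual[idx])
}
compute_auroc(state)
```

Only the count vectors grow with the number of thresholds; no predictions or targets are kept between batches.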
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()){
library(torch)
actual <- c(1, 1, 1, 0, 0, 0)
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
y_true <- torch_tensor(actual)
y_pred <- torch_tensor(predicted)
m <- luz_metric_binary_auroc(thresholds = predicted)
m <- m$new()
m$update(y_pred[1:2], y_true[1:2])
m$update(y_pred[3:4], y_true[3:4])
m$update(y_pred[5:6], y_true[5:6])
m$compute()
}
Mean absolute error
Description
Computes the mean absolute error.
Usage
luz_metric_mae()
Value
Returns new luz metric.
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mse(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
library(torch)
metric <- luz_metric_mae()
metric <- metric$new()
metric$update(torch_randn(100), torch_randn(100))
metric$compute()
}
Mean squared error
Description
Computes the mean squared error
Usage
luz_metric_mse()
Value
A luz_metric object.
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_multiclass_auroc(),
luz_metric_rmse()
Computes the multi-class AUROC
Description
By default, the same definition as Keras
is used. This is equivalent to the 'micro' averaging method in scikit-learn.
See the scikit-learn documentation for details.
Usage
luz_metric_multiclass_auroc(
  num_thresholds = 200,
  thresholds = NULL,
  from_logits = FALSE,
  average = c("micro", "macro", "weighted", "none")
)
Arguments
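| num_thresholds | Number of thresholds used to compute confusion matrices when thresholds is not given.
In that case, thresholds are created by taking num_thresholds values evenly spaced in the unit interval. | 
| thresholds | (optional) If thresholds are passed, then those are used to compute the
confusion matrices and num_thresholds is ignored. | 
| from_logits | If TRUE, predictions are treated as logits and converted to probabilities (using a softmax) before computing the metric. | 
| average | The averaging method: 'micro' stacks all classes and computes the AUROC as a single binary problem; 'macro' computes the AUROC for each class and takes the unweighted mean; 'weighted' is like 'macro' but weights each class by its number of observations; 'none' returns the AUROC of each class.
 | 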
Details
Note that class imbalance can affect this metric unlike the AUC for binary classification.
Currently the AUC is approximated using the 'interpolation' method described in Keras.
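The difference between 'micro' and 'macro' averaging can be illustrated in base R with the data from this section's example. For brevity, the sketch computes each AUC exactly from ranks (the Mann-Whitney formulation, with ties counted as one half) instead of using thresholds and the Keras-style interpolation, so it shows the averaging schemes rather than reproducing luz's output. auc_binary() and multiclass_auc() are hypothetical helpers.

```r
# Hypothetical helper: exact binary AUC from average ranks.
auc_binary <- function(scores, labels) {
  r <- rank(scores)  # average ranks, so a tied pair counts as 1/2
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical helper: 'micro' flattens all classes, 'macro' averages per class.
multiclass_auc <- function(probs, target, average = c("micro", "macro")) {
  average <- match.arg(average)
  onehot <- outer(target, seq_len(ncol(probs)), "==") * 1  # one column per class
  if (average == "micro") {
    return(auc_binary(as.vector(probs), as.vector(onehot)))
  }
  mean(sapply(seq_len(ncol(probs)), function(k) auc_binary(probs[, k], onehot[, k])))
}

actual <- c(2L, 2L, 2L, 1L, 1L, 1L)
p <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
probs <- cbind(1 - p, p)
multiclass_auc(probs, actual, "micro")  # 32.5 / 36, about 0.903
multiclass_auc(probs, actual, "macro")  # 8 / 9, about 0.889
```

Even with two perfectly balanced classes the two schemes differ, because 'micro' also compares scores across columns.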
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_rmse()
Examples
if (torch::torch_is_installed()) {
library(torch)
actual <- c(1, 1, 1, 0, 0, 0) + 1L
predicted <- c(0.9, 0.8, 0.4, 0.5, 0.3, 0.2)
predicted <- cbind(1-predicted, predicted)
y_true <- torch_tensor(as.integer(actual))
y_pred <- torch_tensor(predicted)
m <- luz_metric_multiclass_auroc(thresholds = as.numeric(predicted),
                                 average = "micro")
m <- m$new()
m$update(y_pred[1:2,], y_true[1:2])
m$update(y_pred[3:4,], y_true[3:4])
m$update(y_pred[5:6,], y_true[5:6])
m$compute()
}
Root mean squared error
Description
Computes the root mean squared error.
Usage
luz_metric_rmse()
Value
Returns new luz metric.
See Also
Other luz_metrics: 
luz_metric(),
luz_metric_accuracy(),
luz_metric_binary_accuracy(),
luz_metric_binary_accuracy_with_logits(),
luz_metric_binary_auroc(),
luz_metric_mae(),
luz_metric_mse(),
luz_metric_multiclass_auroc()
Creates a metric set
Description
A metric set can be used to specify metrics that are only evaluated during training, validation or both.
Usage
luz_metric_set(metrics = NULL, train_metrics = NULL, valid_metrics = NULL)
Arguments
| metrics | A list of luz_metrics that are meant to be used in both training and validation. | 
| train_metrics | A list of luz_metrics that are only used during training. | 
| valid_metrics | A list of luz_metrics that are only used for validation. | 
Saves luz objects to disk
Description
Allows saving luz fitted models to the disk. Objects can be loaded back with
luz_load().
Usage
luz_save(obj, path, ...)
Arguments
| obj | an object of class 'luz_module_fitted' as returned by
fit.luz_module_generator(). | 
| path | path in file system to the object. | 
| ... | currently unused. | 
Warning
The ctx is naively serialized, i.e., we only use saveRDS() to serialize it.
Don't expect luz_save to work correctly if you have unserializable objects
in the ctx like torch_tensors and external pointers in general.
Note
Objects are saved as plain .rds files but obj$model is serialized
with torch_save before saving it.
See Also
Other luz_save: 
luz_load()
Loss to be used with luz_callback_mixup().
Description
In the training phase, computes individual losses with regard to two targets, weights them item-wise, and averages the linear combinations to yield the mean batch loss. For validation and testing, defers to the passed-in loss.
Usage
nn_mixup_loss(loss)
Arguments
| loss | the underlying loss nn_module to call. | 
Details
It should be used together with luz_callback_mixup().
See Also
luz_callback_mixup()
Mixup logic
Description
Logic underlying luz_callback_mixup().
Usage
nnf_mixup(x, y, weight)
Arguments
| x | an input batch | 
| y | a target batch | 
| weight | weighting coefficient to be used by  | 
Details
Based on the passed-in input and target batches, as well as applicable mixing weights, we return new tensors intended to replace the current batch. The new input batch is a weighted linear combination of input batch items, while the new target batch bundles the original targets, as well as the mixing weights, in a nested list.
Value
A list of:
-  x, the new, mixed-up input batch
-  y, a list of:
  -  ys, a list of:
    -  y1, the original target y1
    -  y2, the mixed-in target y2
  -  weight, the mixing weights
See Also
luz_callback_mixup()
Examples
if (torch::torch_is_installed()) {
batch_x <- torch::torch_randn(c(10, 768))
batch_y <- torch::torch_randn(10)
weight <- torch::torch_tensor(rep(0.9, 10))$view(c(10, 1))
nnf_mixup(batch_x, batch_y, weight)
}
Create predictions for a fitted model
Description
Create predictions for a fitted model
Usage
## S3 method for class 'luz_module_fitted'
predict(
  object,
  newdata,
  ...,
  callbacks = list(),
  accelerator = NULL,
  verbose = NULL,
  dataloader_options = NULL
)
Arguments
| object | (fitted model) the fitted model object returned from fit.luz_module_generator(). | 
| newdata | (dataloader, dataset, list or array) returning a list with at least 1 element. The other elements aren't used. | 
| ... | Currently unused. | 
| callbacks | (list, optional) A list of callbacks defined with
luz_callback() that will be called during prediction. | 
| accelerator | (accelerator, optional) An optional accelerator() object used to configure device placement. | 
| verbose | (logical, optional) An optional boolean value indicating if
the prediction procedure should emit output to the console.
By default, it will produce output if interactive() is TRUE. | 
| dataloader_options | Options used when creating a dataloader. See
torch::dataloader(). | 
See Also
Other training: 
evaluate(),
fit.luz_module_generator(),
setup()
Objects exported from other packages
Description
These objects are imported from other packages. Follow the links below to see their documentation.
- generics
Set hyper-parameter of a module
Description
This function is used to define hyper-parameters before calling fit for
luz_modules.
Usage
set_hparams(module, ...)
Arguments
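| module | An nn_module that has been setup(). | 
| ... | The parameters set here will be used to initialize the nn_module, i.e., they are passed to its initialize method when the model is fitted. For example, set_hparams(in_features = 10, out_features = 1) for nn_linear. | 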
Value
The same luz module
See Also
Other set_hparam: 
set_opt_hparams()
Set optimizer hyper-parameters
Description
This function is used to define hyper-parameters for the optimizer initialization method.
Usage
set_opt_hparams(module, ...)
Arguments
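| module | An nn_module that has been setup(). | 
| ... | The parameters passed here will be used to initialize the optimizers.
For example, if your optimizer is optim_adam and you pass lr = 0.1, the optimizer is initialized with optim_adam(parameters, lr = 0.1) when the model is fitted. | 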
Value
The same luz module
See Also
Other set_hparam: 
set_hparams()
Sets up an nn_module to use with luz
Description
The setup function is used to set important attributes and methods for nn_modules
to be used with luz.
Usage
setup(module, loss = NULL, optimizer = NULL, metrics = NULL, backward = NULL)
Arguments
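| module | (nn_module) The nn_module that you want to set up. | 
| loss | (function, optional) An optional function with the signature function(input, target). It's only optional if your nn_module implements a loss method. | 
| optimizer | (torch_optimizer, optional) A function with the signature function(parameters, ...) used to initialize an optimizer given the model parameters, e.g. optim_adam. | 
| metrics | (list, optional) A list of metrics to be tracked during training. To evaluate some metrics only during training or only during validation, pass a luz_metric_set() object instead. | 
| backward | (function, optional) A function that takes the loss value as its parameter and computes the gradients, e.g. by calling its $backward() method. In general you only need to set it to customize how luz calls backward(). | 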
Details
It makes sure the module has all the necessary ingredients in order to be fitted.
Value
A luz module that can be trained with fit().
Note
It also adds a device active field that can be used to query the current
module device within methods, e.g. with self$device. This is useful when
ctx() is not available, e.g. when calling methods from outside the luz
wrappers. Users can override the default by implementing a device active
method in the input module.
See Also
Other training: 
evaluate(),
fit.luz_module_generator(),
predict.luz_module_fitted()