diff --git a/Cargo.lock b/Cargo.lock index 228bfcb77d..335afe4e6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2325,6 +2325,16 @@ dependencies = [ "tar", ] +[[package]] +name = "custom-learning-strategy" +version = "0.20.0" +dependencies = [ + "burn", + "derive-new", + "guide", + "log", +] + [[package]] name = "custom-renderer" version = "0.20.0" diff --git a/crates/burn-train/src/components.rs b/crates/burn-train/src/components.rs index e5793a469c..1a7950ef99 100644 --- a/crates/burn-train/src/components.rs +++ b/crates/burn-train/src/components.rs @@ -39,6 +39,7 @@ pub trait LearnerComponentTypes { > + Send; /// The checkpointer used for the scheduler. type CheckpointerLrScheduler: Checkpointer<::Record, Self::Backend>; + /// Processes events happening during training and validation. type EventProcessor: EventProcessorTraining< ItemTrain = ::TrainOutput, ItemValid = ::ValidOutput, diff --git a/crates/burn-train/src/learner/base.rs b/crates/burn-train/src/learner/base.rs index 360e2a60b8..c5686cf2c8 100644 --- a/crates/burn-train/src/learner/base.rs +++ b/crates/burn-train/src/learner/base.rs @@ -20,7 +20,7 @@ pub struct Learner { pub(crate) checkpoint: Option, pub(crate) grad_accumulation: Option, pub(crate) checkpointer: Option>, - pub(crate) learning_strategy: LearningStrategy, + pub(crate) learning_strategy: LearningStrategy, pub(crate) interrupter: Interrupter, pub(crate) early_stopping: Option, pub(crate) event_processor: LC::EventProcessor, @@ -32,7 +32,8 @@ pub struct Learner { pub(crate) type EarlyStoppingStrategyRef = Box; #[derive(new)] -pub(crate) struct LearnerCheckpointer { +/// Used to create, delete, or load checkpoints of the training process. +pub struct LearnerCheckpointer { model: LC::CheckpointerModel, optim: LC::CheckpointerOptimizer, lr_scheduler: LC::CheckpointerLrScheduler, @@ -40,7 +41,8 @@ } impl LearnerCheckpointer { - pub(crate) fn checkpoint( + /// Create a checkpoint of the training process. + pub fn checkpoint( &mut self, model: &LC::Model, optim: &LC::Optimizer, @@ -78,7 +80,8 @@ } } - pub(crate) fn load_checkpoint( + /// Load a training checkpoint.
+ pub fn load_checkpoint( &self, model: LC::Model, optim: LC::Optimizer, diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs index e283bb5f74..3a854dbfec 100644 --- a/crates/burn-train/src/learner/builder.rs +++ b/crates/burn-train/src/learner/builder.rs @@ -22,7 +22,7 @@ use crate::{ ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller, LearnerCheckpointer, LearnerSummaryConfig, LearningStrategy, TrainStep, ValidStep, }; -use burn_core::module::AutodiffModule; +use burn_core::module::{AutodiffModule, Module}; use burn_core::record::FileRecorder; use burn_core::tensor::backend::AutodiffBackend; use burn_optim::Optimizer; @@ -57,7 +57,6 @@ where checkpoint: Option, directory: PathBuf, grad_accumulation: Option, - learning_strategy: LearningStrategy, renderer: Option>, metrics: MetricsTraining, event_store: LogEventStore, @@ -71,6 +70,19 @@ where _p: PhantomData<(TI, VI, TO, VO)>, } +type LC = LearnerComponentsMarker< + B, + S, + M, + O, + AsyncCheckpointer<>::Record, B>, + AsyncCheckpointer<>::Record, B>, + AsyncCheckpointer<::Record, B>, + AsyncProcessorTraining>, + Box, + LearningDataMarker, +>; + impl LearnerBuilder where B: AutodiffBackend, @@ -97,7 +109,6 @@ where checkpointers: None, directory, grad_accumulation: None, - learning_strategy: LearningStrategy::default(), metrics: MetricsTraining::default(), event_store: LogEventStore::default(), renderer: None, @@ -232,12 +243,6 @@ where self } - /// Run the training loop with different strategies - pub fn learning_strategy(mut self, learning_strategy: LearningStrategy) -> Self { - self.learning_strategy = learning_strategy; - self - } - /// The epoch from which the training must resume. pub fn checkpoint(mut self, checkpoint: usize) -> Self { self.checkpoint = Some(checkpoint); @@ -320,20 +325,8 @@ where model: M, optim: O, lr_scheduler: S, - ) -> Learner< - LearnerComponentsMarker< - B, - S, - M, - O, - AsyncCheckpointer, - AsyncCheckpointer, - AsyncCheckpointer, B>, - AsyncProcessorTraining>, - Box, - LearningDataMarker, - >, - > + learning_strategy: LearningStrategy>, + ) -> Learner> where M::Record: 'static, O::Record: 'static, @@ -373,7 +366,7 @@ where None }; - let learning_strategy = Self::prepare_learning_strategy(self.learning_strategy); + let learning_strategy = Self::prepare_learning_strategy(learning_strategy); Learner { model, @@ -392,7 +385,15 @@ where } } - fn prepare_learning_strategy(learning_strategy: LearningStrategy) -> LearningStrategy { + #[allow(clippy::type_complexity)] + fn prepare_learning_strategy( + learning_strategy: LearningStrategy>, + ) -> LearningStrategy> + where + M::Record: 'static, + O::Record: 'static, + S::Record: 'static, + { if let LearningStrategy::MultiDeviceNaive(devices) = &learning_strategy && devices.len() == 1 { diff --git a/crates/burn-train/src/learner/strategies/base.rs b/crates/burn-train/src/learner/strategies/base.rs index 0e509c54de..8920a9cf40 100644 --- a/crates/burn-train/src/learner/strategies/base.rs +++ b/crates/burn-train/src/learner/strategies/base.rs @@ -2,7 +2,7 @@ use std::sync::Arc; #[cfg(feature = "ddp")] use burn_collective::CollectiveConfig; -use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend}; +use burn_core::{module::AutodiffModule, prelude::Backend, tensor::backend::AutodiffBackend}; use crate::{ EarlyStoppingStrategyRef, Interrupter, Learner, LearnerCheckpointer, TrainLoader, @@ -12,23 +12,33 @@ use crate::{ processor::{EventProcessorTraining, LearnerEvent}, 
store::EventStoreClient, }, + multi::CustomMultiDeviceLearningStrategy, + single::CustomSingleDeviceLearningStrategy, }; +type LearnerDevice = <::Backend as Backend>::Device; + /// How should the learner run the learning for the model #[derive(Clone)] -pub enum LearningStrategy { +pub enum LearningStrategy { /// Training on one device - SingleDevice(B::Device), + SingleDevice(LearnerDevice), + + /// Training on one device with a custom learning strategy + CustomSingleDevice(CustomSingleDeviceLearningStrategy), /// Legacy implementation of local multi-device training - MultiDeviceNaive(Vec), + MultiDeviceNaive(Vec>), + + /// Training on multiple devices with a custom learning strategy. + CustomMultiDevice(CustomMultiDeviceLearningStrategy), /// Training with input distributed across devices, each device has its own copy of the model. /// Collective ops are used to sync the gradients after each pass. #[cfg(feature = "ddp")] DistributedDataParallel { /// Devices on this node for the DDP - devices: Vec, + devices: Vec>, /// The configuration for collective operations /// num_devices is ignored @@ -38,21 +48,21 @@ /// Constructor for a distributed data parallel (DDP) learning strategy #[cfg(feature = "ddp")] -pub fn ddp( - devices: Vec, +pub fn ddp( + devices: Vec>, config: CollectiveConfig, -) -> LearningStrategy { +) -> LearningStrategy { LearningStrategy::DistributedDataParallel { devices, config } } -impl Default for LearningStrategy { +impl Default for LearningStrategy { fn default() -> Self { Self::SingleDevice(Default::default()) } } /// Provides the `fit` function for any learning strategy -pub(crate) trait LearningMethod { +pub trait LearningMethod { /// The dataloaders after being prepared for this training strategy /// /// (eg: splitting for multiple devices) @@ -152,14 +162,23 @@ /// Struct to minimise parameters passed to [LearningMethod::learn] /// These components are used during training -pub(crate) struct LearnerComponents { +pub struct LearnerComponents { + /// The [Optimizer](LearnerComponentTypes::Optimizer) used for the training. pub optim: LC::Optimizer, + /// The [learning rate scheduler](LearnerComponentTypes::LrScheduler) used for the training. pub lr_scheduler: LC::LrScheduler, + /// The number of epochs the training should last. pub num_epochs: usize, + /// Enables gradient accumulation. pub grad_accumulation: Option, + /// A [LearnerCheckpointer](LearnerCheckpointer) used to save and load training checkpoints. pub checkpointer: Option>, + /// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early. pub interrupter: Interrupter, + /// [Cloneable reference to an early stopping strategy](EarlyStoppingStrategyRef). pub early_stopping: Option, + /// An [EventProcessor](LearnerComponentTypes::EventProcessor) that processes events happening during training and validation. pub event_processor: LC::EventProcessor, + /// A reference to an [EventStoreClient](EventStoreClient).
pub event_store: Arc, } diff --git a/crates/burn-train/src/learner/strategies/multi/method.rs b/crates/burn-train/src/learner/strategies/multi/method.rs index a1a8d3cf63..f56547a7d4 100644 --- a/crates/burn-train/src/learner/strategies/multi/method.rs +++ b/crates/burn-train/src/learner/strategies/multi/method.rs @@ -4,7 +4,7 @@ use crate::{ multi::epoch::MultiDeviceTrainEpoch, }; use burn_core::{data::dataloader::split::split_dataloader, module::Module, prelude::Backend}; -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; pub struct MultiDeviceLearningStrategy { devices: Vec<::Device>, @@ -19,6 +19,14 @@ impl MultiDeviceLearningStrategy { } } +pub type CustomMultiDeviceLearningStrategy = Arc< + dyn LearningMethod< + LC, + PreparedDataloaders = (Vec>, ValidLoader), + PreparedModel = ::Model, + >, +>; + impl LearningMethod for MultiDeviceLearningStrategy { type PreparedDataloaders = (Vec>, ValidLoader); diff --git a/crates/burn-train/src/learner/strategies/single/method.rs b/crates/burn-train/src/learner/strategies/single/method.rs index 846ad4ce19..de481164ce 100644 --- a/crates/burn-train/src/learner/strategies/single/method.rs +++ b/crates/burn-train/src/learner/strategies/single/method.rs @@ -4,7 +4,7 @@ use crate::{ learner::strategies::single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch}, }; use burn_core::{module::Module, tensor::Device}; -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; /// Simplest learning strategy possible, with only a single device doing both the training and /// validation. @@ -21,6 +21,14 @@ impl SingleDeviceLearningStrategy { } } +pub type CustomSingleDeviceLearningStrategy = Arc< + dyn LearningMethod< + LC, + PreparedDataloaders = (TrainLoader, ValidLoader), + PreparedModel = ::Model, + >, +>; + impl LearningMethod for SingleDeviceLearningStrategy { type PreparedDataloaders = (TrainLoader, ValidLoader); diff --git a/crates/burn-train/src/learner/train_val.rs b/crates/burn-train/src/learner/train_val.rs index 6e7d2f7239..fd45a6a696 100644 --- a/crates/burn-train/src/learner/train_val.rs +++ b/crates/burn-train/src/learner/train_val.rs @@ -103,8 +103,10 @@ pub trait ValidStep { fn step(&self, item: VI) -> VO; } -pub(crate) type TrainLoader = Arc, InputTrain>>; -pub(crate) type ValidLoader = Arc, InputValid>>; +/// A reference to the training split [DataLoader](DataLoader). +pub type TrainLoader = Arc, InputTrain>>; +/// A reference to the validation split [DataLoader](DataLoader). +pub type ValidLoader = Arc, InputValid>>; /// The result of a training, containing the model along with the [renderer](MetricsRenderer).
pub struct TrainingResult { @@ -137,10 +139,16 @@ impl Learner { let single_device = SingleDeviceLearningStrategy::new(device.clone()); single_device.fit(self, dataloader_train, dataloader_valid) } + LearningStrategy::CustomSingleDevice(learning_strategy) => learning_strategy + .clone() + .fit(self, dataloader_train, dataloader_valid), LearningStrategy::MultiDeviceNaive(devices) => { let multi_device = MultiDeviceLearningStrategy::new(devices.clone()); multi_device.fit(self, dataloader_train, dataloader_valid) } + LearningStrategy::CustomMultiDevice(learning_strategy) => learning_strategy + .clone() + .fit(self, dataloader_train, dataloader_valid), #[cfg(feature = "ddp")] LearningStrategy::DistributedDataParallel { devices, config } => { diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs index 59548d02b5..d5313c8c07 100644 --- a/crates/burn-train/src/lib.rs +++ b/crates/burn-train/src/lib.rs @@ -20,6 +20,8 @@ pub mod logger; /// The metric module. pub mod metric; +pub use metric::processor::*; + mod learner; pub use learner::*; @@ -28,6 +30,8 @@ mod evaluator; pub use evaluator::*; +pub use components::LearnerComponentTypes; + #[cfg(test)] pub(crate) type TestBackend = burn_ndarray::NdArray; diff --git a/crates/burn-train/src/metric/processor/async_wrapper.rs b/crates/burn-train/src/metric/processor/async_wrapper.rs index 3fbc5eeaf2..debd4ce6c1 100644 --- a/crates/burn-train/src/metric/processor/async_wrapper.rs +++ b/crates/burn-train/src/metric/processor/async_wrapper.rs @@ -3,10 +3,12 @@ use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation}; use super::{EventProcessorTraining, LearnerEvent}; use async_channel::{Receiver, Sender}; +/// Event processor for the training process. pub struct AsyncProcessorTraining { sender: Sender>, } +/// Event processor for the model evaluation. pub struct AsyncProcessorEvaluation { sender: Sender>, } @@ -58,6 +60,7 @@ impl WorkerEvaluation

{ } impl AsyncProcessorTraining

{ + /// Create an event processor for training. pub fn new(processor: P) -> Self { let (sender, rec) = async_channel::bounded(1); @@ -68,6 +71,7 @@ impl AsyncProcessorTraining

{ } impl AsyncProcessorEvaluation

{ + /// Create an event processor for model evaluation. pub fn new(processor: P) -> Self { let (sender, rec) = async_channel::bounded(1); diff --git a/examples/custom-image-dataset/src/training.rs b/examples/custom-image-dataset/src/training.rs index 8fa1f4e0fd..c8d0305373 100644 --- a/examples/custom-image-dataset/src/training.rs +++ b/examples/custom-image-dataset/src/training.rs @@ -103,13 +103,13 @@ pub fn train(config: TrainingConfig, device: B::Device) { .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) - .learning_strategy(LearningStrategy::SingleDevice(device.clone())) .num_epochs(config.num_epochs) .summary() .build( Cnn::new(NUM_CLASSES.into(), &device), config.optimizer.init(), config.learning_rate, + LearningStrategy::SingleDevice(device.clone()), ); // Training diff --git a/examples/custom-learning-strategy/Cargo.toml b/examples/custom-learning-strategy/Cargo.toml new file mode 100644 index 0000000000..e7f1f5a342 --- /dev/null +++ b/examples/custom-learning-strategy/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "custom-learning-strategy" +edition.workspace = true +license.workspace = true +version.workspace = true +publish = false + +[lints] +workspace = true + +[dependencies] +burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]} +guide = {path = "../guide"} +derive-new = { workspace = true } +log = { workspace = true } diff --git a/examples/custom-learning-strategy/examples/custom-learning-strategy.rs b/examples/custom-learning-strategy/examples/custom-learning-strategy.rs new file mode 100644 index 0000000000..cae7563a51 --- /dev/null +++ b/examples/custom-learning-strategy/examples/custom-learning-strategy.rs @@ -0,0 +1,5 @@ +use burn::backend::{Autodiff, WebGpu}; + +fn main() { + custom_learning_strategy::training::run::>(Default::default()); +} diff --git a/examples/custom-learning-strategy/src/lib.rs b/examples/custom-learning-strategy/src/lib.rs new file mode 100644 index 0000000000..9be447d875 --- /dev/null +++ b/examples/custom-learning-strategy/src/lib.rs @@ -0,0 +1,2 @@ +pub mod model; +pub mod training; diff --git a/examples/custom-learning-strategy/src/model.rs b/examples/custom-learning-strategy/src/model.rs new file mode 100644 index 0000000000..c341f5fe6f --- /dev/null +++ b/examples/custom-learning-strategy/src/model.rs @@ -0,0 +1,99 @@ +use burn::{ + nn::{ + Dropout, DropoutConfig, Linear, LinearConfig, Relu, + conv::{Conv2d, Conv2dConfig}, + loss::CrossEntropyLossConfig, + pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, + }, + prelude::*, + tensor::backend::AutodiffBackend, + train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}, +}; +use guide::data::MnistBatch; + +#[derive(Module, Debug)] +pub struct Model { + conv1: Conv2d, + conv2: Conv2d, + pool: AdaptiveAvgPool2d, + dropout: Dropout, + linear1: Linear, + linear2: Linear, + activation: Relu, +} + +#[derive(Config, Debug)] +pub struct ModelConfig { + num_classes: usize, + hidden_size: usize, + #[config(default = "0.5")] + dropout: f64, +} + +impl ModelConfig { + /// Returns the initialized model. 
+ pub fn init(&self, device: &B::Device) -> Model { + Model { + conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device), + conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device), + pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(), + activation: Relu::new(), + linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device), + linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device), + dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +impl Model { + /// # Shapes + /// - Images [batch_size, height, width] + /// - Output [batch_size, class_prob] + pub fn forward(&self, images: Tensor) -> Tensor { + let [batch_size, height, width] = images.dims(); + + // Create a channel. + let x = images.reshape([batch_size, 1, height, width]); + + let x = self.conv1.forward(x); // [batch_size, 8, _, _] + let x = self.dropout.forward(x); + let x = self.conv2.forward(x); // [batch_size, 16, _, _] + let x = self.dropout.forward(x); + let x = self.activation.forward(x); + + let x = self.pool.forward(x); // [batch_size, 16, 8, 8] + let x = x.reshape([batch_size, 16 * 8 * 8]); + let x = self.linear1.forward(x); + let x = self.dropout.forward(x); + let x = self.activation.forward(x); + + self.linear2.forward(x) // [batch_size, num_classes] + } + + pub fn forward_classification(&self, item: MnistBatch) -> ClassificationOutput { + let targets = item.targets; + let output = self.forward(item.images); + let loss = CrossEntropyLossConfig::new() + .init(&output.device()) + .forward(output.clone(), targets.clone()); + + ClassificationOutput { + loss, + output, + targets, + } + } +} + +impl TrainStep, ClassificationOutput> for Model { + fn step(&self, item: MnistBatch) -> TrainOutput> { + let item = self.forward_classification(item); + TrainOutput::new(self, item.loss.backward(), item) + } +} + +impl ValidStep, ClassificationOutput> for Model { + fn step(&self, batch: MnistBatch) -> ClassificationOutput { + self.forward_classification(batch) + } +} diff --git a/examples/custom-learning-strategy/src/training.rs b/examples/custom-learning-strategy/src/training.rs new file mode 100644 index 0000000000..3216aa30f1 --- /dev/null +++ b/examples/custom-learning-strategy/src/training.rs @@ -0,0 +1,215 @@ +use burn::train::EventProcessorTraining; + +use std::{marker::PhantomData, sync::Arc}; + +use crate::model::ModelConfig; +use burn::{ + data::{ + dataloader::DataLoaderBuilder, + dataset::{transform::PartialDataset, vision::MnistDataset}, + }, + lr_scheduler::{ + LrScheduler, composed::ComposedLrSchedulerConfig, cosine::CosineAnnealingLrSchedulerConfig, + linear::LinearLrSchedulerConfig, + }, + module::AutodiffModule, + optim::AdamConfig, + prelude::*, + record::CompactRecorder, + tensor::{Device, backend::AutodiffBackend}, + train::{ + LearnerBuilder, LearnerComponentTypes, LearnerEvent, LearnerItem, LearningMethod, + MetricEarlyStoppingStrategy, StoppingCondition, TrainLoader, TrainStep, ValidLoader, + ValidStep, + metric::{ + AccuracyMetric, LearningRateMetric, LossMetric, + store::{Aggregate, Direction, Split}, + }, + }, +}; +use guide::data::MnistBatcher; + +static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist"; + +#[derive(Config, Debug)] +pub struct MnistTrainingConfig { + #[config(default = 10)] + pub num_epochs: usize, + #[config(default = 64)] + pub batch_size: usize, + #[config(default = 4)] + pub num_workers: usize, + #[config(default = 42)] + pub seed: u64, + #[config(default = 1e-4)] + pub lr: f64, + pub model: ModelConfig, + pub optimizer: AdamConfig, +} + +fn 
create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts before to get an accurate learner summary + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +pub fn run(device: B::Device) { + create_artifact_dir(ARTIFACT_DIR); + // Config + let config_model = ModelConfig::new(10, 1024); + let config_optimizer = AdamConfig::new(); + let config = MnistTrainingConfig::new(config_model, config_optimizer); + + B::seed(&device, config.seed); + + let model = config.model.init::(&device); + + let dataset_train_original = Arc::new(MnistDataset::train()); + let dataset_train = PartialDataset::new(dataset_train_original.clone(), 0, 55_000); + let dataset_valid = PartialDataset::new(dataset_train_original.clone(), 55_000, 60_000); + + let dataloader_train = DataLoaderBuilder::new(MnistBatcher::default()) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(dataset_train); + let dataloader_valid = DataLoaderBuilder::new(MnistBatcher::default()) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(dataset_valid); + let lr_scheduler = ComposedLrSchedulerConfig::new() + .cosine(CosineAnnealingLrSchedulerConfig::new(1.0, 2000)) + // Warmup + .linear(LinearLrSchedulerConfig::new(1e-8, 1.0, 2000)) + .linear(LinearLrSchedulerConfig::new(1e-2, 1e-6, 10000)); + let early_stopping = MetricEarlyStoppingStrategy::new( + &LossMetric::::new(), + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + StoppingCondition::NoImprovementSince { n_epochs: 5 }, + ); + + let learner = LearnerBuilder::new(ARTIFACT_DIR) + .metrics((AccuracyMetric::new(), LossMetric::new())) + .metric_train_numeric(LearningRateMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .early_stopping(early_stopping) + .num_epochs(config.num_epochs) + .summary() + .build( + model, + config.optimizer.init(), + lr_scheduler.init().unwrap(), + burn::train::LearningStrategy::CustomSingleDevice(Arc::new( + MyCustomLearningStrategy::new(device), + )), + ); + + learner.fit(dataloader_train, dataloader_valid); +} + +struct MyCustomLearningStrategy { + device: Device, + _p: PhantomData, +} + +impl MyCustomLearningStrategy { + pub fn new(device: Device) -> Self { + Self { + device, + _p: PhantomData, + } + } +} + +impl LearningMethod for MyCustomLearningStrategy { + type PreparedDataloaders = (TrainLoader, ValidLoader); + + type PreparedModel = ::Model; + + fn prepare_dataloaders( + &self, + dataloader_train: TrainLoader, + dataloader_valid: ValidLoader, + ) -> Self::PreparedDataloaders { + // The reference model is always on the first device provided. + let train = dataloader_train.to_device(&self.device); + let valid = dataloader_valid.to_device(&self.device); + + (train, valid) + } + + fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel { + model.fork(&self.device) + } + + fn learn( + &self, + mut model: Self::PreparedModel, + (dataloader_train, dataloader_valid): Self::PreparedDataloaders, + starting_epoch: usize, + components: burn::train::LearnerComponents, + ) -> (LC::Model, LC::EventProcessor) { + let mut scheduler = components.lr_scheduler; + let mut optim = components.optim; + let mut processor = components.event_processor; + + for epoch in starting_epoch..components.num_epochs + 1 { + // Iterate over our training and validation loop for X epochs. 
+ log::info!("Executing training step for epoch {}", epoch,); + + // Single device / dataloader + let mut iterator = dataloader_train.iter(); + let mut iteration = 0; + + while let Some(item) = iterator.next() { + iteration += 1; + let lr = scheduler.step(); + log::info!("Iteration {iteration}"); + + let progress = iterator.progress(); + let item = model.step(item); + model = model.optimize(&mut optim, lr, item.grads); + + let item = LearnerItem::new( + item.item, + progress, + epoch, + components.num_epochs, + iteration, + Some(lr), + ); + + processor.process_train(LearnerEvent::ProcessedItem(item)); + } + processor.process_train(LearnerEvent::EndEpoch(epoch)); + + let model_valid = model.valid(); + + let mut iterator = dataloader_valid.iter(); + let mut iteration = 0; + + while let Some(item) = iterator.next() { + let progress = iterator.progress(); + iteration += 1; + + let item = model_valid.step(item); + let item = LearnerItem::new( + item, + progress, + epoch, + components.num_epochs, + iteration, + None, + ); + + processor.process_valid(LearnerEvent::ProcessedItem(item)); + } + processor.process_valid(LearnerEvent::EndEpoch(epoch)); + } + + (model, processor) + } +} diff --git a/examples/custom-renderer/src/lib.rs b/examples/custom-renderer/src/lib.rs index 23326646be..7c9e73adb0 100644 --- a/examples/custom-renderer/src/lib.rs +++ b/examples/custom-renderer/src/lib.rs @@ -89,14 +89,18 @@ pub fn run(device: B::Device) { // artifact dir does not need to be provided when log_to_file is false let builder = LearnerBuilder::new("") - .learning_strategy(LearningStrategy::SingleDevice(device.clone())) .num_epochs(config.num_epochs) .renderer(CustomRenderer {}) .with_application_logger(None); // can be used to interrupt training let _interrupter = builder.interrupter(); - let learner = builder.build(model, optim, config.lr); + let learner = builder.build( + model, + optim, + config.lr, + LearningStrategy::SingleDevice(device.clone()), + ); let _model_trained = learner.fit(dataloader_train, dataloader_test); } diff --git a/examples/guide/src/training.rs b/examples/guide/src/training.rs index 7dcb0a68d4..d766900989 100644 --- a/examples/guide/src/training.rs +++ b/examples/guide/src/training.rs @@ -94,13 +94,13 @@ pub fn train(artifact_dir: &str, config: TrainingConfig, dev .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) - .learning_strategy(LearningStrategy::SingleDevice(device.clone())) .num_epochs(config.num_epochs) .summary() .build( config.model.init::(&device), config.optimizer.init(), config.learning_rate, + LearningStrategy::SingleDevice(device.clone()), ); let result = learner.fit(dataloader_train, dataloader_test); diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs index 6d887a54db..301e191dfd 100644 --- a/examples/mnist/src/training.rs +++ b/examples/mnist/src/training.rs @@ -5,7 +5,6 @@ use crate::{ model::Model, }; -use burn::optim::AdamWConfig; use burn::{ data::{ dataloader::DataLoaderBuilder, @@ -31,6 +30,7 @@ use burn::{ renderer::MetricsRenderer, }, }; +use burn::{optim::AdamWConfig, train::LearningStrategy}; static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist"; @@ -107,8 +107,12 @@ pub fn run(device: B::Device) { )) .num_epochs(config.num_epochs) .summary() - .learning_strategy(burn::train::LearningStrategy::SingleDevice(device)) - .build(model, config.optimizer.init(), lr_scheduler.init().unwrap()); + .build( + model, + config.optimizer.init(), + 
lr_scheduler.init().unwrap(), + LearningStrategy::SingleDevice(device), + ); let result = learner.fit(dataloader_train, dataloader_valid); diff --git a/examples/simple-regression/src/training.rs b/examples/simple-regression/src/training.rs index 8297cb79a8..1d348ac4ca 100644 --- a/examples/simple-regression/src/training.rs +++ b/examples/simple-regression/src/training.rs @@ -70,10 +70,14 @@ pub fn run(artifact_dir: &str, device: B::Device) { .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) - .learning_strategy(LearningStrategy::SingleDevice(device.clone())) .num_epochs(config.num_epochs) .summary() - .build(model, config.optimizer.init(), 1e-3); + .build( + model, + config.optimizer.init(), + 1e-3, + LearningStrategy::SingleDevice(device.clone()), + ); let result = learner.fit(dataloader_train, dataloader_test); diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs index 122a802f0a..1513afdee1 100644 --- a/examples/text-classification/src/training.rs +++ b/examples/text-classification/src/training.rs @@ -37,7 +37,7 @@ pub struct ExperimentConfig { pub optimizer: AdamConfig, #[config(default = "SeqLengthOption::Fixed(256)")] pub seq_length: SeqLengthOption, - #[config(default = 32)] + #[config(default = 8)] pub batch_size: usize, #[config(default = 5)] pub num_epochs: usize, @@ -98,10 +98,14 @@ pub fn train( .metric_valid_numeric(AccuracyMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) - .learning_strategy(LearningStrategy::MultiDeviceNaive(devices)) .num_epochs(config.num_epochs) .summary() - .build(model, optim, lr_scheduler); + .build( + model, + optim, + lr_scheduler, + LearningStrategy::MultiDeviceNaive(devices), + ); #[cfg(feature = "ddp")] let collective_config = diff --git a/examples/text-generation/src/training.rs b/examples/text-generation/src/training.rs index ce5b4558fc..51f58fd932 100644 --- a/examples/text-generation/src/training.rs +++ b/examples/text-generation/src/training.rs @@ -79,11 +79,15 @@ pub fn train + 'static>( .metric_valid(LossMetric::new()) .metric_train_numeric(LearningRateMetric::new()) .with_file_checkpointer(CompactRecorder::new()) - .learning_strategy(LearningStrategy::SingleDevice(device)) .grads_accumulation(accum) .num_epochs(config.num_epochs) .summary() - .build(model, optim, lr_scheduler); + .build( + model, + optim, + lr_scheduler, + LearningStrategy::SingleDevice(device.clone()), + ); let result = learner.fit(dataloader_train, dataloader_test);
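
Usage sketch: with this patch the learning strategy is no longer set through the removed LearnerBuilder::learning_strategy(..) method; it is passed as the last argument of build(..), and a user-defined strategy can be plugged in through the new CustomSingleDevice / CustomMultiDevice variants. The snippet below is a minimal sketch based on the updated examples; artifact_dir, config, model, optim, lr_scheduler, device and the dataloaders are assumed to be set up as in examples/guide, and MyCustomLearningStrategy refers to the LearningMethod implementation added in examples/custom-learning-strategy.

    use std::sync::Arc;
    use burn::train::{LearnerBuilder, LearningStrategy};

    // The strategy is now the final argument of `build` rather than a builder method.
    let learner = LearnerBuilder::new(artifact_dir)
        .num_epochs(config.num_epochs)
        .summary()
        .build(
            model,
            optim,
            lr_scheduler,
            // Built-in alternative: LearningStrategy::SingleDevice(device.clone())
            LearningStrategy::CustomSingleDevice(Arc::new(MyCustomLearningStrategy::new(device))),
        );

    let result = learner.fit(dataloader_train, dataloader_valid);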