diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py
index b391f15daf..1e32579f92 100644
--- a/benchmarks/benchmark_aq.py
+++ b/benchmarks/benchmark_aq.py
@@ -16,32 +16,7 @@
     _replace_with_custom_fn_if_matches_filter,
     quantize_,
 )
-
-
-class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size"""
-
-    def __init__(
-        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
-    ):
-        super().__init__()
-        self.m = m
-        self.dtype = dtype
-        self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
-            dtype=self.dtype, device=self.device
-        )
-
-    def example_inputs(self):
-        return (
-            torch.randn(
-                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
-            ),
-        )
-
-    def forward(self, x):
-        x = self.linear(x)
-        return x
+from torchao.testing.model_architectures import ToySingleLinearModel
 
 
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
@@ -70,8 +45,8 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
 
 @torch.no_grad
 def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
-    m = ToyLinearModel(
-        M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
+    m = ToySingleLinearModel(
+        K, N, dtype=torch.bfloat16, device="cuda", has_bias=True
     ).eval()
     m_bf16 = copy.deepcopy(m)
     example_inputs = m.example_inputs()
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 738e9b6164..5decc9de0e 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -38,6 +38,7 @@
     choose_qparams_affine,
 )
 from torchao.quantization.quantize_.common import KernelPreference
+from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.utils import (
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -48,18 +49,6 @@
 torch.manual_seed(0)
 
 
-class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
-
-    def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        return x
-
-
 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -122,7 +111,7 @@ def test_fp8_linear_variants(
         }
 
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")
 
         quantized_model = copy.deepcopy(model)
         factory = mode_map[mode]()
@@ -179,7 +168,7 @@ def test_per_row_with_float32(self):
             AssertionError,
             match="PerRow quantization only works for bfloat16 precision",
         ):
-            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            model = ToyTwoLinearModel(64, 64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
                 model,
                 Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
@@ -192,7 +181,7 @@ def test_per_row_with_float32(self):
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
-        model = ToyLinearModel(16, 32).to(device="cuda")
+        model = ToyTwoLinearModel(16, 32, 16).to(device="cuda")
 
         mode_map = {
             "dynamic": partial(
@@ -224,7 +213,7 @@ def test_serialization(self, mode: str):
 
         # Create a new model and load the state dict
         with torch.device("meta"):
-            new_model = ToyLinearModel(16, 32)
+            new_model = ToyTwoLinearModel(16, 32, 16)
         if mode == "static":
             quantize_(new_model, factory)
         new_model.load_state_dict(loaded_state_dict, assign=True)
@@ -266,7 +255,7 @@ def test_serialization(self, mode: str):
     )
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
-        model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights
+        model = ToyTwoLinearModel(10, 25, 10).cuda()  # 10x25 and 25x10 weights
 
         # Set up logging capture
         with self.assertLogs(
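For reviewers without the shared test utilities at hand, below is a rough sketch of what the imported helpers plausibly look like, reconstructed from the deleted local classes and the new call sites (ToySingleLinearModel(K, N, dtype=..., device=..., has_bias=True) and ToyTwoLinearModel(16, 32, 16)). The parameter names (input_dim, hidden_dim, output_dim, batch_size) and defaults are assumptions; the authoritative definitions live in torchao/testing/model_architectures.py and may differ.

    # Hypothetical sketch of the shared toy models -- not the actual
    # torchao.testing.model_architectures source; names and defaults
    # are inferred from the call sites in this diff.
    import torch


    class ToySingleLinearModel(torch.nn.Module):
        """One linear layer; replaces the benchmark's local ToyLinearModel."""

        def __init__(
            self, input_dim, output_dim, has_bias=False,
            dtype=torch.float, device="cuda",
        ):
            super().__init__()
            self.dtype = dtype
            self.device = device
            self.linear = torch.nn.Linear(input_dim, output_dim, bias=has_bias).to(
                dtype=dtype, device=device
            )

        def example_inputs(self, batch_size=1):
            # batch_size is an assumed parameter; the benchmark calls
            # example_inputs() with no arguments.
            return (
                torch.randn(
                    batch_size, self.linear.in_features,
                    dtype=self.dtype, device=self.device,
                ),
            )

        def forward(self, x):
            return self.linear(x)


    class ToyTwoLinearModel(torch.nn.Module):
        """Two stacked linears; ToyTwoLinearModel(K, N, K) reproduces the
        old test class's Linear(K, N) -> Linear(N, K) shape."""

        def __init__(self, input_dim, hidden_dim, output_dim, has_bias=False):
            super().__init__()
            self.linear1 = torch.nn.Linear(input_dim, hidden_dim, bias=has_bias)
            self.linear2 = torch.nn.Linear(hidden_dim, output_dim, bias=has_bias)

        def forward(self, x):
            return self.linear2(self.linear1(x))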