Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 3 additions & 28 deletions benchmarks/benchmark_aq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,7 @@
_replace_with_custom_fn_if_matches_filter,
quantize_,
)


class ToyLinearModel(torch.nn.Module):
    """Single linear layer for an ``m * k * n`` problem size.

    Args:
        m: Number of rows in the tensor produced by ``example_inputs``.
        n: Output features of the linear layer.
        k: Input features of the linear layer.
        has_bias: Whether the linear layer carries a bias term.
        dtype: Dtype the layer is converted to and example inputs are
            created with.
        device: Device the layer is moved to and example inputs are
            created on.
    """

    def __init__(
        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
    ):
        super().__init__()
        self.m = m
        self.dtype = dtype
        self.device = device
        # Build the layer first, then convert it to the requested dtype/device.
        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
            dtype=self.dtype, device=self.device
        )

    def example_inputs(self):
        """Return a 1-tuple holding a random ``(m, in_features)`` input tensor."""
        return (
            torch.randn(
                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
            ),
        )

    def forward(self, x):
        """Apply the single linear layer to ``x`` and return the result."""
        return self.linear(x)
from torchao.testing.model_architectures import ToySingleLinearModel


def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
Expand Down Expand Up @@ -70,8 +45,8 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):

@torch.no_grad
def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
m = ToyLinearModel(
M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
m = ToySingleLinearModel(
K, N, dtype=torch.bfloat16, device="cuda", has_bias=True
).eval()
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs()
Expand Down
23 changes: 6 additions & 17 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
choose_qparams_affine,
)
from torchao.quantization.quantize_.common import KernelPreference
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.utils import (
is_sm_at_least_89,
is_sm_at_least_90,
Expand All @@ -48,18 +49,6 @@
torch.manual_seed(0)


class ToyLinearModel(torch.nn.Module):
    """Two bias-free linear layers applied back to back.

    ``linear1`` maps ``in_features -> out_features`` and ``linear2`` maps
    ``out_features -> in_features``, so the output has the same trailing
    dimension as the input.

    Args:
        in_features: Trailing dimension of the input (and of the output).
        out_features: Width of the intermediate activation.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)

    def forward(self, x):
        """Run ``x`` through ``linear1`` and then ``linear2``."""
        return self.linear2(self.linear1(x))


class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
Expand Down Expand Up @@ -122,7 +111,7 @@ def test_fp8_linear_variants(
}

# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
model = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)
factory = mode_map[mode]()
Expand Down Expand Up @@ -179,7 +168,7 @@ def test_per_row_with_float32(self):
AssertionError,
match="PerRow quantization only works for bfloat16 precision",
):
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
model = ToyTwoLinearModel(64, 64, 64).eval().to(torch.float32).to("cuda")
quantize_(
model,
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
Expand All @@ -192,7 +181,7 @@ def test_per_row_with_float32(self):
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
def test_serialization(self, mode: str):
# Create and quantize the model
model = ToyLinearModel(16, 32).to(device="cuda")
model = ToyTwoLinearModel(16, 32, 16).to(device="cuda")

mode_map = {
"dynamic": partial(
Expand Down Expand Up @@ -224,7 +213,7 @@ def test_serialization(self, mode: str):

# Create a new model and load the state dict
with torch.device("meta"):
new_model = ToyLinearModel(16, 32)
new_model = ToyTwoLinearModel(16, 32, 16)
if mode == "static":
quantize_(new_model, factory)
new_model.load_state_dict(loaded_state_dict, assign=True)
Expand Down Expand Up @@ -266,7 +255,7 @@ def test_serialization(self, mode: str):
)
def test_fp8_weight_dimension_warning(self):
# Create model with incompatible dimensions (not multiples of 16)
model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights
model = ToyTwoLinearModel(10, 25, 10).cuda() # 10x25 and 25x10 weights

# Set up logging capture
with self.assertLogs(
Expand Down
Loading