45 changes: 43 additions & 2 deletions sqlmesh/core/context.py
@@ -115,6 +115,7 @@
ModelTestMetadata,
generate_test,
run_tests,
filter_tests_by_patterns,
)
from sqlmesh.core.user import User
from sqlmesh.utils import UniqueKeyDict, Verbosity
@@ -146,8 +147,8 @@
from typing_extensions import Literal

from sqlmesh.core.engine_adapter._typing import (
BigframeSession,
DF,
BigframeSession,
PySparkDataFrame,
PySparkSession,
SnowparkSession,
@@ -398,6 +399,8 @@ def __init__(
self._standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict(
"standaloneaudits"
)
self._models_with_tests: t.Set[str] = set()
self._model_test_metadata: t.List[ModelTestMetadata] = []
self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
self._jinja_macros = JinjaMacroRegistry()
@@ -636,6 +639,8 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
self._excluded_requirements.clear()
self._linters.clear()
self._environment_statements = []
self._models_with_tests.clear()
self._model_test_metadata.clear()

for loader, project in zip(self._loaders, loaded_projects):
self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
@@ -647,6 +652,8 @@
self._requirements.update(project.requirements)
self._excluded_requirements.update(project.excluded_requirements)
self._environment_statements.extend(project.environment_statements)
self._models_with_tests.update(project.models_with_tests)
self._model_test_metadata.extend(project.model_test_metadata)

config = loader.config
self._linters[config.project] = Linter.from_rules(
@@ -1049,6 +1056,11 @@ def standalone_audits(self) -> MappingProxyType[str, StandaloneAudit]:
"""Returns all registered standalone audits in this context."""
return MappingProxyType(self._standalone_audits)

@property
def models_with_tests(self) -> t.Set[str]:
"""Returns all models with tests in this context."""
return self._models_with_tests

@property
def snapshots(self) -> t.Dict[str, Snapshot]:
"""Generates and returns snapshots based on models registered in this context.
@@ -2220,7 +2232,9 @@ def test(

pd.set_option("display.max_columns", None)

test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
test_meta = self._filter_preloaded_tests(
test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns
)

result = run_tests(
model_test_metadata=test_meta,
@@ -2782,6 +2796,33 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
return self.engine_adapter

def _filter_preloaded_tests(
Contributor: What motivated this addition? It seems like something that doesn't need to be included in the scope of this PR.

Author: If I remember correctly, the test() method was set up to still lazily load the tests, so I had to implement this to ensure it maintained the filtering functionality and used the eagerly loaded tests instead.

Contributor: Ah, makes sense. I suggest renaming this to _select_tests instead, then.

self,
test_meta: t.List[ModelTestMetadata],
tests: t.Optional[t.List[str]] = None,
patterns: t.Optional[t.List[str]] = None,
) -> t.List[ModelTestMetadata]:
"""Filter pre-loaded test metadata based on tests and patterns."""

if tests:
filtered_tests = []
for test in tests:
if "::" in test:
filename, test_name = test.split("::", maxsplit=1)
test_path = Path(filename)
filtered_tests.extend(
[t for t in test_meta if t.path == test_path and t.test_name == test_name]
)
else:
test_path = Path(test)
filtered_tests.extend([t for t in test_meta if t.path == test_path])
Comment on lines +2813 to +2818
Contributor: These iterations over test_meta look costly. Can we store a mapping of test_name -> test_meta instead of keeping everything in a list and having to walk it every time?

test_meta = filtered_tests

if patterns:
test_meta = filter_tests_by_patterns(test_meta, patterns)

return test_meta
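
A minimal sketch of the indexing approach the reviewer suggests above, not the PR's implementation; the _TestIndex name and its import line are assumptions for illustration. The preloaded metadata is grouped once by path and by (path, test_name), so resolving "file" or "file::test_name" selectors becomes a dictionary lookup instead of a scan over the whole list:

import typing as t
from collections import defaultdict
from pathlib import Path

from sqlmesh.core.test import ModelTestMetadata  # assumed import path for this sketch


class _TestIndex:
    """Hypothetical helper: index preloaded test metadata for constant-time selection."""

    def __init__(self, test_meta: t.List[ModelTestMetadata]) -> None:
        self.by_path: t.Dict[Path, t.List[ModelTestMetadata]] = defaultdict(list)
        self.by_path_and_name: t.Dict[t.Tuple[Path, str], t.List[ModelTestMetadata]] = defaultdict(list)
        for meta in test_meta:
            self.by_path[meta.path].append(meta)
            self.by_path_and_name[(meta.path, meta.test_name)].append(meta)

    def select(self, tests: t.List[str]) -> t.List[ModelTestMetadata]:
        # Resolve "path" or "path::test_name" selectors against the prebuilt index.
        selected: t.List[ModelTestMetadata] = []
        for test in tests:
            if "::" in test:
                filename, test_name = test.split("::", maxsplit=1)
                selected.extend(self.by_path_and_name.get((Path(filename), test_name), []))
            else:
                selected.extend(self.by_path.get(Path(test), []))
        return selected

If selection turns out to be on a hot path, such an index could be built once during load() and reused across test() invocations rather than rebuilt per call.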

def _snapshots(
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
) -> t.Dict[str, Snapshot]:
15 changes: 15 additions & 0 deletions sqlmesh/core/linter/rules/builtin.py
@@ -129,6 +129,21 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return self.violation()


class NoMissingUnitTest(Rule):
"""All models must have a unit test found in the test/ directory yaml files"""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
# External models cannot have unit tests
if isinstance(model, ExternalModel):
return None

if model.name not in self.context.models_with_tests:
return self.violation(
violation_msg=f"Model {model.name} is missing unit test(s). Please add in the tests/ directory."
)
return None


class NoMissingExternalModels(Rule):
"""All external models must be registered in the external_models.yaml file"""

10 changes: 10 additions & 0 deletions sqlmesh/core/loader.py
@@ -64,6 +64,8 @@ class LoadedProject:
excluded_requirements: t.Set[str]
environment_statements: t.List[EnvironmentStatements]
user_rules: RuleSet
model_test_metadata: t.List[ModelTestMetadata]
models_with_tests: t.Set[str]
Comment on lines +67 to +68
Contributor: Why is this an attribute of the LoadedProject? We just need the metadata, right? Then the other consumers (e.g., the context or the linter) can extract this info and use it as needed.

Author (cmgoffena13, Nov 20, 2025): Yeah, good catch. I can move this to the linter and keep the core code clean. Or actually, the context might be a little better, but I'm open to both.

Contributor: Make sure this is computed once and is kept in the context so that we don't recompute it on every model being linted. It's a global property.

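A rough illustration of the reviewer's last point, assuming a hypothetical _ContextSketch class and the use of functools.cached_property (neither is part of this PR): the set of models with tests is derived once from the eagerly loaded metadata and then reused by every linter rule invocation instead of being recomputed per model.

import typing as t
from functools import cached_property

from sqlmesh.core.test import ModelTestMetadata  # assumed import path for this sketch


class _ContextSketch:
    """Hypothetical stand-in for the context, showing compute-once semantics."""

    def __init__(self, model_test_metadata: t.List[ModelTestMetadata]) -> None:
        self._model_test_metadata = model_test_metadata

    @cached_property
    def models_with_tests(self) -> t.Set[str]:
        # Computed on first access and cached, so linting N models does not
        # rebuild this set N times.
        return {meta.model_name for meta in self._model_test_metadata}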


class CacheBase(abc.ABC):
@@ -243,6 +245,12 @@ def load(self) -> LoadedProject:

user_rules = self._load_linting_rules()

model_test_metadata = self.load_model_tests()

models_with_tests = {
model_test_metadata.model_name for model_test_metadata in model_test_metadata
}

project = LoadedProject(
macros=macros,
jinja_macros=jinja_macros,
@@ -254,6 +262,8 @@ def load(self) -> LoadedProject:
excluded_requirements=excluded_requirements,
environment_statements=environment_statements,
user_rules=user_rules,
model_test_metadata=model_test_metadata,
models_with_tests=models_with_tests,
)
return project

4 changes: 4 additions & 0 deletions sqlmesh/core/test/discovery.py
@@ -20,6 +20,10 @@ class ModelTestMetadata(PydanticModel):
def fully_qualified_test_name(self) -> str:
return f"{self.path}::{self.test_name}"

@property
def model_name(self) -> str:
return self.body["model"]
Comment on lines +23 to +25
Contributor: We should be a bit more conservative here with the lookup. IIRC the body is validated later at runtime, when the tests are actually run, by which time it'll be too late if the test is missing its name, and we don't want to throw a KeyError.

Suggested change:
-    @property
-    def model_name(self) -> str:
-        return self.body["model"]
+    @property
+    def model_name(self) -> str:
+        return self.body.get("model", "")

Author: Gotcha, quick fix, I'll get it done.


def __hash__(self) -> int:
return self.fully_qualified_test_name.__hash__()

60 changes: 60 additions & 0 deletions tests/core/linter/test_builtin.py
@@ -172,3 +172,63 @@ def test_no_missing_external_models_with_existing_file_not_ending_in_newline(
)
fix_path = sushi_path / "external_models.yaml"
assert edit.path == fix_path


def test_no_missing_unit_tests(tmp_path, copy_to_temp_path):
"""
Tests that the NoMissingUnitTest linter rule correctly identifies models
without corresponding unit tests in the tests/ directory

This test checks the sushi example project, enables the linter,
and verifies that the linter raises a rule violation for the models
that do not have a unit test
"""
sushi_paths = copy_to_temp_path("examples/sushi")
sushi_path = sushi_paths[0]

# Override the config.py to turn on lint
with open(sushi_path / "config.py", "r") as f:
read_file = f.read()

before = """ linter=LinterConfig(
enabled=False,
rules=[
"ambiguousorinvalidcolumn",
"invalidselectstarexpansion",
"noselectstar",
"nomissingaudits",
"nomissingowner",
"nomissingexternalmodels",
],
),"""
after = """linter=LinterConfig(enabled=True, rules=["nomissingunittest"]),"""
read_file = read_file.replace(before, after)
assert after in read_file
with open(sushi_path / "config.py", "w") as f:
f.writelines(read_file)

# Load the context with the temporary sushi path
context = Context(paths=[sushi_path])

# Lint the models
lints = context.lint_models(raise_on_error=False)

# Should have violations for models without tests (most models except customers)
assert len(lints) >= 1

# Check that we get violations for models without tests
violation_messages = [lint.violation_msg for lint in lints]
assert any("is missing unit test(s)" in msg for msg in violation_messages)

# Check that models with existing tests don't have violations
models_with_tests = ["customer_revenue_by_day", "customer_revenue_lifetime", "order_items"]

for model_name in models_with_tests:
model_violations = [
lint
for lint in lints
if model_name in lint.violation_msg and "is missing unit test(s)" in lint.violation_msg
]
assert len(model_violations) == 0, (
f"Model {model_name} should not have a violation since it has a test"
)
29 changes: 23 additions & 6 deletions tests/core/test_test.py
@@ -1539,6 +1539,9 @@ def test_gateway(copy_to_temp_path: t.Callable, mocker: MockerFixture) -> None:
with open(test_path, "w", encoding="utf-8") as file:
dump_yaml(test_dict, file)

# Re-initialize context to pick up the modified test file
context = Context(paths=path, config=config)

spy_execute = mocker.spy(EngineAdapter, "_execute")
mocker.patch("sqlmesh.core.test.definition.random_id", return_value="jzngz56a")

@@ -2448,6 +2451,9 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
copy_test_file(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml", i)
copy_test_file(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml", i)

# Re-initialize context to pick up the new test files
context = Context(paths=tmp_path, config=config)

with capture_output() as captured_output:
context.test()

@@ -2463,13 +2469,12 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
"SELECT 1 AS col_1, 2 AS col_2, 3 AS col_3, 4 AS col_4, 5 AS col_5, 6 AS col_6, 7 AS col_7"
)

context.upsert_model(
_create_model(
meta="MODEL(name test.test_wide_model)",
query=wide_model_query,
default_catalog=context.default_catalog,
)
wide_model = _create_model(
meta="MODEL(name test.test_wide_model)",
query=wide_model_query,
default_catalog=context.default_catalog,
)
context.upsert_model(wide_model)

tests_dir = tmp_path / "tests"
tests_dir.mkdir()
@@ -2493,6 +2498,9 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:

wide_test_file.write_text(wide_test_file_content)

context.load()
context.upsert_model(wide_model)

with capture_output() as captured_output:
context.test()

@@ -2549,6 +2557,9 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
"""
)

# Re-initialize context to pick up the modified test file
context = Context(paths=tmp_path, config=config)

with capture_output() as captured_output:
context.test()

@@ -3472,6 +3483,9 @@ def test_cte_failure(tmp_path: Path) -> None:
"""
)

# Re-initialize context to pick up the new test file
context = Context(paths=tmp_path, config=config)

with capture_output() as captured_output:
context.test()

@@ -3498,6 +3512,9 @@ def test_cte_failure(tmp_path: Path) -> None:
"""
)

# Re-initialize context to pick up the modified test file
context = Context(paths=tmp_path, config=config)

with capture_output() as captured_output:
context.test()
