diff --git a/bigframes/__init__.py b/bigframes/__init__.py index 240608ebc2..f282e092e4 100644 --- a/bigframes/__init__.py +++ b/bigframes/__init__.py @@ -16,12 +16,24 @@ from bigframes._config import option_context, options from bigframes._config.bigquery_options import BigQueryOptions +from bigframes._magics import _cell_magic from bigframes.core.global_session import close_session, get_global_session import bigframes.enums as enums import bigframes.exceptions as exceptions from bigframes.session import connect, Session from bigframes.version import __version__ +_MAGIC_NAMES = ["bigquery_sql"] + + +def load_ipython_extension(ipython): + """Called by IPython when this module is loaded as an IPython extension.""" + for magic_name in _MAGIC_NAMES: + ipython.register_magic_function( + _cell_magic, magic_kind="cell", magic_name=magic_name + ) + + __all__ = [ "options", "BigQueryOptions", diff --git a/bigframes/_magics.py b/bigframes/_magics.py new file mode 100644 index 0000000000..33205e3a37 --- /dev/null +++ b/bigframes/_magics.py @@ -0,0 +1,56 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from IPython.core import magic_arguments # type: ignore +from IPython.core.getipython import get_ipython +from IPython.display import display + +import bigframes.pandas + + +@magic_arguments.magic_arguments() +@magic_arguments.argument( + "destination_var", + nargs="?", + help=("If provided, save the output to this variable instead of displaying it."), +) +@magic_arguments.argument( + "--dry_run", + action="store_true", + default=False, + help=( + "Sets query to be a dry run to estimate costs. " + "Defaults to executing the query instead of dry run if this argument is not used." + "Does not work with engine 'bigframes'. " + ), +) +def _cell_magic(line, cell): + ipython = get_ipython() + args = magic_arguments.parse_argstring(_cell_magic, line) + if not cell: + print("Query is missing.") + return + pyformat_args = ipython.user_ns + dataframe = bigframes.pandas._read_gbq_colab( + cell, pyformat_args=pyformat_args, dry_run=args.dry_run + ) + if args.destination_var: + ipython.push({args.destination_var: dataframe}) + else: + with bigframes.option_context( + "display.repr_mode", + "anywidget", + ): + display(dataframe) + return diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index 483bc5e530..b8ad6cbd0c 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -49,6 +49,7 @@ import pyarrow as pa import bigframes._config as config +import bigframes._importing import bigframes.core.global_session as global_session import bigframes.core.indexes import bigframes.dataframe @@ -356,8 +357,12 @@ def _read_gbq_colab( with warnings.catch_warnings(): # Don't warning about Polars in SQL cell. # Related to b/437090788. - warnings.simplefilter("ignore", bigframes.exceptions.PreviewWarning) - config.options.bigquery.enable_polars_execution = True + try: + bigframes._importing.import_polars() + warnings.simplefilter("ignore", bigframes.exceptions.PreviewWarning) + config.options.bigquery.enable_polars_execution = True + except ImportError: + pass # don't fail if polars isn't available return global_session.with_default_session( bigframes.session.Session._read_gbq_colab, diff --git a/tests/system/small/test_magics.py b/tests/system/small/test_magics.py new file mode 100644 index 0000000000..aed759e4ac --- /dev/null +++ b/tests/system/small/test_magics.py @@ -0,0 +1,95 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from IPython.testing.globalipapp import get_ipython +from IPython.utils.capture import capture_output +import pandas as pd +import pytest + +import bigframes +import bigframes.pandas as bpd + +MAGIC_NAME = "bigquery_sql" + + +@pytest.fixture(scope="module") +def ip(): + """Provides a persistent IPython shell instance for the test session.""" + shell = get_ipython() + shell.extension_manager.load_extension("bigframes") + return shell + + +def test_magic_select_lit_to_var(ip): + bigframes.close_session() + + line = "dst_var" + cell_body = "SELECT 3" + + ip.run_cell_magic(MAGIC_NAME, line, cell_body) + + assert "dst_var" in ip.user_ns + result_df = ip.user_ns["dst_var"] + assert result_df.shape == (1, 1) + assert result_df.loc[0, 0] == 3 + + +def test_magic_select_lit_dry_run(ip): + bigframes.close_session() + + line = "dst_var --dry_run" + cell_body = "SELECT 3" + + ip.run_cell_magic(MAGIC_NAME, line, cell_body) + + assert "dst_var" in ip.user_ns + result_df = ip.user_ns["dst_var"] + assert result_df.totalBytesProcessed == 0 + + +def test_magic_select_lit_display(ip): + bigframes.close_session() + + cell_body = "SELECT 3" + + with capture_output() as io: + ip.run_cell_magic(MAGIC_NAME, "", cell_body) + assert len(io.outputs) > 0 + html_data = io.outputs[0].data["text/html"] + assert "[1 rows x 1 columns in total]" in html_data + + +def test_magic_select_interpolate(ip): + bigframes.close_session() + df = bpd.read_pandas( + pd.DataFrame({"col_a": [1, 2, 3, 4, 5, 6], "col_b": [1, 2, 1, 3, 1, 2]}) + ) + const_val = 1 + + ip.push({"df": df, "const_val": const_val}) + + query = """ + SELECT + SUM(col_a) AS total + FROM + {df} + WHERE col_b={const_val} + """ + + ip.run_cell_magic(MAGIC_NAME, "dst_var", query) + + assert "dst_var" in ip.user_ns + result_df = ip.user_ns["dst_var"] + assert result_df.shape == (1, 1) + assert result_df.loc[0, 0] == 9