diff --git a/core/opengate_core/opengate_core.cpp b/core/opengate_core/opengate_core.cpp index b0c02b93e..d8b636e46 100644 --- a/core/opengate_core/opengate_core.cpp +++ b/core/opengate_core/opengate_core.cpp @@ -307,6 +307,8 @@ void init_GateTrackCreatorProcessFilter(py::module &); void init_GateKineticEnergyFilter(py::module &); +void init_GateFilterData(py::module &); + // Gate actors void init_GateDoseActor(py::module &m); @@ -578,6 +580,7 @@ PYBIND11_MODULE(opengate_core, m) { init_GatePrimaryScatterFilter(m); init_GateTrackCreatorProcessFilter(m); init_GateKineticEnergyFilter(m); + init_GateFilterData(m); init_GateThresholdAttributeFilter(m); init_itk_image(m); init_GateImageNestedParameterisation(m); diff --git a/core/opengate_core/opengate_lib/GateFilterData.cpp b/core/opengate_core/opengate_lib/GateFilterData.cpp new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/core/opengate_core/opengate_lib/GateFilterData.cpp @@ -0,0 +1 @@ + diff --git a/core/opengate_core/opengate_lib/GateFilterData.h b/core/opengate_core/opengate_lib/GateFilterData.h new file mode 100644 index 000000000..e82306dd7 --- /dev/null +++ b/core/opengate_core/opengate_lib/GateFilterData.h @@ -0,0 +1,431 @@ +#ifndef OPENGATE_CORE_OPENGATE_LIB_GATEFILTERDATA_H +#define OPENGATE_CORE_OPENGATE_LIB_GATEFILTERDATA_H + +#include +#include +#include + +#include "GatePrimaryScatterFilter.h" +#include "GateUniqueVolumeIDManager.h" +#include "GateUserEventInformation.h" + +namespace attr { + +// Energy +struct TotalEnergyDeposit; +struct PostKineticEnergy; +struct PreKineticEnergy; +struct KineticEnergy; +struct TrackVertexKineticEnergy; +struct EventKineticEnergy; + +// Time +struct LocalTime; +struct GlobalTime; +struct PreGlobalTime; +struct TimeFromBeginOfEvent; +struct TrackProperTime; + +// Misc +struct Weight; +struct TrackID; +struct ParentID; +struct EventID; +struct RunID; +struct ThreadID; +struct TrackCreatorProcess; +struct TrackCreatorModelName; +struct TrackCreatorModelIndex; +struct ProcessDefinedStep; +struct ParticleName; +struct ParentParticleName; +struct ParticleType; +struct TrackVolumeName; +struct TrackVolumeCopyNo; +struct PreStepVolumeCopyNo; +struct PostStepVolumeCopyNo; +struct TrackVolumeInstanceID; +struct PreStepUniqueVolumeID; +struct PostStepUniqueVolumeID; +struct PDGCode; +struct HitUniqueVolumeID; + +// Position +struct Position; +struct PostPosition; +struct PrePosition; +struct PrePositionLocal; +struct PostPositionLocal; +struct EventPosition; +struct TrackVertexPosition; + +// Direction +struct Direction; +struct PostDirection; +struct PreDirection; +struct PreDirectionLocal; +struct TrackVertexMomentumDirection; +struct EventDirection; + +// Polarization +struct Polarization; + +// Length +struct StepLength; +struct TrackLength; + +// Scatter information +struct UnscatteredPrimaryFlag; + +} // namespace attr + +template struct GetAttr; + +// Energy +template <> struct GetAttr { + static double get(G4Step *step) { return step->GetTotalEnergyDeposit(); } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetPostStepPoint()->GetKineticEnergy(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetPreStepPoint()->GetKineticEnergy(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetPreStepPoint()->GetKineticEnergy(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetTrack()->GetVertexKineticEnergy(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *) { + auto const *event = G4RunManager::GetRunManager()->GetCurrentEvent(); + return event->GetPrimaryVertex(0)->GetPrimary(0)->GetKineticEnergy(); + } +}; + +// Time +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetPostStepPoint()->GetLocalTime(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetPostStepPoint()->GetGlobalTime(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + return step->GetPreStepPoint()->GetGlobalTime(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { + auto const *event = G4RunManager::GetRunManager()->GetCurrentEvent(); + auto const globalTime = step->GetTrack()->GetGlobalTime(); + return globalTime - event->GetPrimaryVertex(0)->GetT0(); + } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { return step->GetTrack()->GetProperTime(); } +}; + +// Misc +template <> struct GetAttr { + static double get(G4Step *step) { return step->GetTrack()->GetWeight(); } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + return step->GetTrack()->GetTrackID(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + return step->GetTrack()->GetParentID(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *) { + return G4RunManager::GetRunManager()->GetCurrentEvent()->GetEventID(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *) { + return G4RunManager::GetRunManager()->GetCurrentRun()->GetRunID(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *) { return G4Threading::G4GetThreadId(); } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + auto const *creatorProcess = step->GetTrack()->GetCreatorProcess(); + if (creatorProcess) + return creatorProcess->GetProcessName(); + return "none"; + } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + return step->GetTrack()->GetCreatorModelName(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + return step->GetTrack()->GetCreatorModelIndex(); + } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + auto const *p = step->GetPreStepPoint()->GetProcessDefinedStep(); + if (p) + return p->GetProcessName(); + return "none"; + } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + return step->GetTrack()->GetParticleDefinition()->GetParticleName(); + } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + auto const *event = G4RunManager::GetRunManager()->GetCurrentEvent(); + auto const *info = + dynamic_cast(event->GetUserInformation()); + if (info) { + auto const trackId = step->GetTrack()->GetParentID(); + return info->GetParticleName(trackId); + } + return "no_user_event_info"; + } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + return step->GetTrack()->GetParticleDefinition()->GetParticleType(); + } +}; + +template <> struct GetAttr { + static std::string get(G4Step *step) { + return step->GetTrack()->GetVolume()->GetName(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + return step->GetTrack()->GetVolume()->GetCopyNo(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const *touchable = step->GetPreStepPoint()->GetTouchable(); + auto const depth = touchable->GetHistoryDepth(); + return touchable->GetVolume(depth)->GetCopyNo(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const *touchable = step->GetPostStepPoint()->GetTouchable(); + auto const depth = touchable->GetHistoryDepth(); + return touchable->GetVolume(depth)->GetCopyNo(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + return step->GetTrack()->GetVolume()->GetInstanceID(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto *manager = GateUniqueVolumeIDManager::GetInstance(); + auto const *touchable = step->GetPreStepPoint()->GetTouchable(); + return manager->GetVolumeID(touchable); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto *manager = GateUniqueVolumeIDManager::GetInstance(); + auto const *touchable = step->GetPostStepPoint()->GetTouchable(); + return manager->GetVolumeID(touchable); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + return step->GetTrack()->GetParticleDefinition()->GetPDGEncoding(); + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto *manager = GateUniqueVolumeIDManager::GetInstance(); + auto const processName = + step->GetPostStepPoint()->GetProcessDefinedStep()->GetProcessName(); + if (processName == "Transportation") + return manager->GetVolumeID(step->GetPreStepPoint()->GetTouchable()); + else + return manager->GetVolumeID(step->GetPostStepPoint()->GetTouchable()); + } +}; + +// Position +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const pos = step->GetPostStepPoint()->GetPosition(); + return std::vector{pos}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const pos = step->GetPostStepPoint()->GetPosition(); + return std::vector{pos}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const pos = step->GetPreStepPoint()->GetPosition(); + return std::vector{pos}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const *touchable = step->GetPreStepPoint()->GetTouchable(); + auto pos = step->GetPreStepPoint()->GetPosition(); + touchable->GetHistory()->GetTopTransform().ApplyPointTransform(pos); + return std::vector{pos}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const *touchable = step->GetPostStepPoint()->GetTouchable(); + auto pos = step->GetPostStepPoint()->GetPosition(); + touchable->GetHistory()->GetTopTransform().ApplyPointTransform(pos); + return std::vector{pos}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *) { + auto const *event = G4RunManager::GetRunManager()->GetCurrentEvent(); + auto const pos = event->GetPrimaryVertex(0)->GetPosition(); + return std::vector{pos}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const pos = step->GetTrack()->GetVertexPosition(); + return std::vector{pos}; + } +}; + +// Direction +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const dir = step->GetPostStepPoint()->GetMomentumDirection(); + return std::vector{dir}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const dir = step->GetPostStepPoint()->GetMomentumDirection(); + return std::vector{dir}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const dir = step->GetPreStepPoint()->GetMomentumDirection(); + return std::vector{dir}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const *touchable = step->GetPostStepPoint()->GetTouchable(); + auto dir = step->GetPreStepPoint()->GetMomentumDirection(); + touchable->GetHistory()->GetTopTransform().TransformAxis(dir); + return std::vector{dir}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const dir = step->GetTrack()->GetVertexMomentumDirection(); + return std::vector{dir}; + } +}; + +template <> struct GetAttr { + static decltype(auto) get(G4Step *) { + auto const *event = G4RunManager::GetRunManager()->GetCurrentEvent(); + auto const dir = + event->GetPrimaryVertex(0)->GetPrimary(0)->GetMomentumDirection(); + return std::vector{dir}; + } +}; + +// Polarization +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { + auto const pol = step->GetTrack()->GetPolarization(); + return std::vector{pol}; + } +}; + +// Length +template <> struct GetAttr { + static double get(G4Step *step) { return step->GetStepLength(); } +}; + +template <> struct GetAttr { + static double get(G4Step *step) { return step->GetTrack()->GetTrackLength(); } +}; + +// Scatter information +template <> struct GetAttr { + static decltype(auto) get(G4Step *step) { return IsUnscatteredPrimary(step); } +}; + +#endif diff --git a/core/opengate_core/opengate_lib/GateUserEventInformation.cpp b/core/opengate_core/opengate_lib/GateUserEventInformation.cpp index c09b8623f..c26b92f07 100644 --- a/core/opengate_core/opengate_lib/GateUserEventInformation.cpp +++ b/core/opengate_core/opengate_lib/GateUserEventInformation.cpp @@ -12,7 +12,7 @@ void GateUserEventInformation::Print() const { // FIXME } -std::string GateUserEventInformation::GetParticleName(G4int track_id) { +std::string GateUserEventInformation::GetParticleName(G4int track_id) const { if (fMapOfParticleName.count(track_id) > 0) { return fMapOfParticleName.at(track_id); } else diff --git a/core/opengate_core/opengate_lib/GateUserEventInformation.h b/core/opengate_core/opengate_lib/GateUserEventInformation.h index 138dd8067..056be2ced 100644 --- a/core/opengate_core/opengate_lib/GateUserEventInformation.h +++ b/core/opengate_core/opengate_lib/GateUserEventInformation.h @@ -21,7 +21,7 @@ class GateUserEventInformation : public G4VUserEventInformation { void Print() const override; - std::string GetParticleName(G4int track_id); + std::string GetParticleName(G4int track_id) const; void BeginOfEventAction(const G4Event *event); diff --git a/core/opengate_core/opengate_lib/GateVActor.cpp b/core/opengate_core/opengate_lib/GateVActor.cpp index 68064703c..2fb020907 100644 --- a/core/opengate_core/opengate_lib/GateVActor.cpp +++ b/core/opengate_core/opengate_lib/GateVActor.cpp @@ -110,14 +110,14 @@ bool GateVActor::HasAction(const std::string &action) { bool GateVActor::IsSensitiveDetector() { return HasAction("SteppingAction"); }; void GateVActor::PreUserTrackingAction(const G4Track *track) { - for (auto f : fFilters) { + for (auto f : fFilters) { // TODO: does not seem to do anything if (!f->Accept(track)) return; } } void GateVActor::PostUserTrackingAction(const G4Track *track) { - for (auto f : fFilters) { + for (auto f : fFilters) { // TODO: does not seem to do anything if (!f->Accept(track)) return; } @@ -140,6 +140,24 @@ G4bool GateVActor::ProcessHits(G4Step *step, G4TouchableHistory *) { => so we decide to simplify and remove "touchable" in the following. */ + // if using C++-compiled function + // if (fFilter) { + // if (fFilter(reinterpret_cast(step))) + // SteppingAction(step); + // return true; + // } + + // else, if using Python pybind function + if (fFilter) { + py::gil_scoped_acquire gil_acquire; + auto accept = fFilter(step); + if (py::cast(accept)) { + // py::gil_scoped_release gil_release; + SteppingAction(step); + } + return true; + } + // if the operator is AND, we perform the SteppingAction only if ALL filters // are true (If only one is false, we stop and return) if (fOperatorIsAnd) { diff --git a/core/opengate_core/opengate_lib/GateVActor.h b/core/opengate_core/opengate_lib/GateVActor.h index 92abaaa32..1d84be4dc 100644 --- a/core/opengate_core/opengate_lib/GateVActor.h +++ b/core/opengate_core/opengate_lib/GateVActor.h @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace py = pybind11; @@ -162,6 +163,12 @@ class GateVActor : public G4VPrimitiveScorer { bool fWriteToDisk; GateSourceManager *fSourceManager; + + // Filter function + // using FilterFunction = std::function; + using FilterFunction = py::function; + FilterFunction fFilter; + void SetFilterFunction(FilterFunction filter) { fFilter = filter; } }; #endif // GateVActor_h diff --git a/core/opengate_core/opengate_lib/pyGateFilterData.cpp b/core/opengate_core/opengate_lib/pyGateFilterData.cpp new file mode 100644 index 000000000..288b7979c --- /dev/null +++ b/core/opengate_core/opengate_lib/pyGateFilterData.cpp @@ -0,0 +1,87 @@ +#include "GateFilterData.h" +#include +#include + +namespace py = pybind11; + +void init_GateFilterData(py::module &m) { + m + // Energy + .def("GetAttrTotalEnergyDeposit", &GetAttr::get) + .def("GetAttrPostKineticEnergy", &GetAttr::get) + .def("GetAttrPreKineticEnergy", &GetAttr::get) + .def("GetAttrKineticEnergy", &GetAttr::get) + .def("GetAttrTrackVertexKineticEnergy", + &GetAttr::get) + .def("GetAttrEventKineticEnergy", &GetAttr::get) + + // Time + .def("GetAttrLocalTime", &GetAttr::get) + .def("GetAttrGlobalTime", &GetAttr::get) + .def("GetAttrPreGlobalTime", &GetAttr::get) + .def("GetAttrTimeFromBeginOfEvent", + &GetAttr::get) + .def("GetAttrTrackProperTime", &GetAttr::get) + + // Misc + .def("GetAttrWeight", &GetAttr::get) + .def("GetAttrTrackID", &GetAttr::get) + .def("GetAttrParentID", &GetAttr::get) + .def("GetAttrEventID", &GetAttr::get) + .def("GetAttrRunID", &GetAttr::get) + .def("GetAttrThreadID", &GetAttr::get) + .def("GetAttrTrackCreatorProcess", + &GetAttr::get) + .def("GetAttrTrackCreatorModelName", + &GetAttr::get) + .def("GetAttrTrackCreatorModelIndex", + &GetAttr::get) + .def("GetAttrProcessDefinedStep", &GetAttr::get) + .def("GetAttrParticleName", &GetAttr::get) + .def("GetAttrParentParticleName", &GetAttr::get) + .def("GetAttrParticleType", &GetAttr::get) + .def("GetAttrTrackVolumeName", &GetAttr::get) + .def("GetAttrTrackVolumeCopyNo", &GetAttr::get) + .def("GetAttrPreStepVolumeCopyNo", + &GetAttr::get) + .def("GetAttrPostStepVolumeCopyNo", + &GetAttr::get) + .def("GetAttrTrackVolumeInstanceID", + &GetAttr::get) + .def("GetAttrPreStepUniqueVolumeID", + &GetAttr::get) + .def("GetAttrPostStepUniqueVolumeID", + &GetAttr::get) + .def("GetAttrPDGCode", &GetAttr::get) + .def("GetAttrHitUniqueVolumeID", &GetAttr::get) + + // Position + .def("GetAttrPosition", &GetAttr::get) + .def("GetAttrPostPosition", &GetAttr::get) + .def("GetAttrPrePosition", &GetAttr::get) + .def("GetAttrPrePositionLocal", &GetAttr::get) + .def("GetAttrPostPositionLocal", &GetAttr::get) + .def("GetAttrEventPosition", &GetAttr::get) + .def("GetAttrTrackVertexPosition", + &GetAttr::get) + + // Direction + .def("GetAttrDirection", &GetAttr::get) + .def("GetAttrPostDirection", &GetAttr::get) + .def("GetAttrPreDirection", &GetAttr::get) + .def("GetAttrPreDirectionLocal", &GetAttr::get) + .def("GetAttrTrackVertexMomentumDirection", + &GetAttr::get) + .def("GetAttrEventDirection", &GetAttr::get) + + // Polarization + .def("GetAttrPolarization", &GetAttr::get) + + // Length + .def("GetAttrStepLength", &GetAttr::get) + .def("GetAttrTrackLength", &GetAttr::get) + + // Scatter information + .def("GetAttrUnscatteredPrimaryFlag", + &GetAttr::get); +} diff --git a/core/opengate_core/opengate_lib/pyGateVActor.cpp b/core/opengate_core/opengate_lib/pyGateVActor.cpp index 69e1759d1..cd793fa62 100644 --- a/core/opengate_core/opengate_lib/pyGateVActor.cpp +++ b/core/opengate_core/opengate_lib/pyGateVActor.cpp @@ -81,6 +81,7 @@ void init_GateVActor(py::module &m) { // .def_readonly("fActions", &GateVActor::fActions) // avoid wrapping // this -> problems with pickle .def_readwrite("fFilters", &GateVActor::fFilters) + .def("SetFilterFunction", &GateVActor::SetFilterFunction) .def("Close", &GateVActor::Close) .def("InitializeCpp", &GateVActor::InitializeCpp) .def("InitializeUserInfo", &GateVActor::InitializeUserInfo) diff --git a/opengate/actors/base.py b/opengate/actors/base.py index d8eec6aee..bd4d1dada 100644 --- a/opengate/actors/base.py +++ b/opengate/actors/base.py @@ -1,3 +1,4 @@ +import ast from box import Box from functools import wraps @@ -6,6 +7,10 @@ from ..base import GateObject, process_cls from ..utility import insert_suffix_before_extension from .actoroutput import ActorOutputRoot +from ..filters.ast import FilterASTTransformer +import ROOT +import tempfile +import opengate_core def _setter_hook_attached_to(self, attached_to): @@ -87,6 +92,7 @@ class ActorBase(GateObject): # hints for IDE attached_to: str + filter: str filters: list filters_boolean_operator: str priority: int @@ -105,6 +111,12 @@ class ActorBase(GateObject): "deprecated": "The user input parameter 'mother' is deprecated. Use 'attached_to' instead. ", }, ), + "filter": ( + "", + { + "doc": "Filter used by this actor. ", + }, + ), "filters": ( [], { @@ -492,6 +504,96 @@ def initialize(self): f"Does the actor class somehow inherit from GateVActor (as it should)?" ) + # set filter function + if self.filter != "": + self._build_and_set_filter_function() + + def _build_and_set_filter_function(self): + tr = FilterASTTransformer() + e = ast.parse(self.filter, mode="eval") + e = tr.visit(e) + + name = "local_filter" + args = ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="step")], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) + body = [ast.Return(value=e.body)] + e = ast.FunctionDef(name=name, args=args, body=body, decorator_list=[]) + + imports = [ + ast.Import(names=[ast.alias(name="numpy", asname="np")]), + ast.Import(names=[ast.alias(name="opengate_core", asname=None)]), + ast.Import(names=[ast.alias(name="opengate.filters.ast", asname=None)]), + ast.ImportFrom( + module="filters.ast", + names=[ + ast.alias(name="dbgp", asname=None), + ], + level=2, + ), + ] + + module = ast.Module(body=[*imports, e], type_ignores=[]) + ast.fix_missing_locations(module) + + print(ast.dump(module)) + + exec(compile(module, filename="", mode="exec"), globals()) + + # include_paths = [ + # "/home/alexis/work/external/geant4/geant4-v11.3.0/source/track/include", + # "/home/alexis/work/external/geant4/geant4-v11.3.0/source/global/management/include", + # "/home/alexis/work/external/geant4/geant4-v11.3.0/source/particles/management/include", + # "/home/alexis/work/external/geant4/geant4-v11.3.0/source/geometry/management/include", + # "/home/alexis/work/external/geant4/geant4-v11.3.0/source/global/HEPGeometry/include", + # "/home/alexis/work/external/geant4/geant4-v11.3.0/source/materials/include", + # "/home/alexis/work/external/geant4/geant4-v11.3.0/build/source/global/include", + # ] + + # cpp_fun = """ + # bool cpp_local_filter(void* ptr) { + # auto* step = reinterpret_cast(ptr); + # return step->GetTrack()->GetParticleDefinition()->GetParticleName() == "proton"; + # } + # """ + # for include_path in include_paths: + # ROOT.gInterpreter.AddIncludePath(include_path) + # ROOT.gInterpreter.Declare("#include ") + # ROOT.gInterpreter.Declare("#include ") + # ROOT.gInterpreter.Declare("#include ") + # ROOT.gInterpreter.Declare(cpp_fun) + + # cpp_code = """ + # #include + # #include + # #include + # + # bool cpp_local_filter(void* ptr) { + # auto* step = reinterpret_cast(ptr); + # return step->GetTrack()->GetParticleDefinition()->GetParticleName() == "proton"; + # } + # """ + # + # cpp_file = "" + # with tempfile.NamedTemporaryFile(suffix=".C", mode="w", delete=False) as f: + # f.write(cpp_code) + # cpp_file = f.name + # + # for include_path in include_paths: + # ROOT.gSystem.AddIncludePath(f"-I{include_path}") + + # ROOT.gSystem.SetMakeSharedLib(f"g++ -O2 -Wall -fPIC -shared -o $LibName $ObjectFiles") + # ROOT.gSystem.CompileMacro(cpp_file, "kO") + # ROOT.gSystem.Load("/tmp/tmpfhe1bkxu_C.so") + + self.SetFilterFunction(local_filter) + # self.SetFilterFunction(ROOT.cpp_local_filter) + # self.SetFilterFunction(llocal_filter) + def _init_user_output_instance(self): for output_name, output_config in self._processed_user_output_config.items(): try: diff --git a/opengate/filters/ast.py b/opengate/filters/ast.py new file mode 100644 index 000000000..173a28a9b --- /dev/null +++ b/opengate/filters/ast.py @@ -0,0 +1,174 @@ +import ast +import inspect +import sys +from typing import Callable, Any +from ..utility import g4_units +import opengate_core as g4 + + +# make local variables for each G4 unit +# TODO define guideline for unit naming when * or / +for key in g4_units: + locals().update({key: g4_units[key]}) + + +def dbgp(s): + print(f"[dbgp] {s}", file=sys.stderr) + return True + + +class Attribute: + name: str + get: Callable[[g4.G4Step], Any] + + def __init__(self, name, get): + self.name = name + self.get = get + + +################################################## +# Energy +total_energy_deposit = Attribute("total_energy_deposit", g4.GetAttrTotalEnergyDeposit) +post_kinetic_energy = Attribute("post_kinetic_energy", g4.GetAttrPostKineticEnergy) +pre_kinetic_energy = Attribute("pre_kinetic_energy", g4.GetAttrPreKineticEnergy) +kinetic_energy = Attribute("kinetic_energy", g4.GetAttrKineticEnergy) +track_vertex_kinetic_energy = Attribute( + "track_vertex_kinetic_energy", g4.GetAttrTrackVertexKineticEnergy +) +event_kinetic_energy = Attribute("event_kinetic_energy", g4.GetAttrEventKineticEnergy) + +# Time +local_time = Attribute("local_time", g4.GetAttrLocalTime) +global_time = Attribute("global_time", g4.GetAttrGlobalTime) +pre_global_time = Attribute("pre_global_time", g4.GetAttrPreGlobalTime) +time_from_begin_of_event = Attribute( + "time_from_begin_of_event", g4.GetAttrTimeFromBeginOfEvent +) +track_proper_time = Attribute("track_proper_time", g4.GetAttrTrackProperTime) + +# Misc +weight = Attribute("weight", g4.GetAttrWeight) +track_id = Attribute("track_id", g4.GetAttrTrackID) +parent_id = Attribute("parent_id", g4.GetAttrParentID) +event_id = Attribute("event_id", g4.GetAttrEventID) +run_id = Attribute("run_id", g4.GetAttrRunID) +thread_id = Attribute("thread_id", g4.GetAttrThreadID) +track_creator_process = Attribute( + "track_creator_process", g4.GetAttrTrackCreatorProcess +) +track_creator_model_name = Attribute( + "track_creator_model_name", g4.GetAttrTrackCreatorModelName +) +track_creator_model_index = Attribute( + "track_creator_model_index", g4.GetAttrTrackCreatorModelIndex +) +process_defined_step = Attribute("process_defined_step", g4.GetAttrProcessDefinedStep) +particle_name = Attribute("particle_name", g4.GetAttrParticleName) +parent_particle_name = Attribute("parent_particle_name", g4.GetAttrParentParticleName) +particle_type = Attribute("particle_type", g4.GetAttrParticleType) +track_volume_name = Attribute("track_volume_name", g4.GetAttrTrackVolumeName) +track_volume_copy_no = Attribute("track_volume_copy_no", g4.GetAttrTrackVolumeCopyNo) +pre_step_volume_copy_no = Attribute( + "pre_step_volume_copy_no", g4.GetAttrPreStepVolumeCopyNo +) +post_step_volume_copy_no = Attribute( + "post_step_volume_copy_no", g4.GetAttrPostStepVolumeCopyNo +) +track_volume_instance_id = Attribute( + "track_volume_instance_id", g4.GetAttrTrackVolumeInstanceID +) +pre_step_unique_volume_id = Attribute( + "pre_step_unique_volume_id", g4.GetAttrPreStepUniqueVolumeID +) +post_step_unique_volume_id = Attribute( + "post_step_unique_volume_id", g4.GetAttrPostStepUniqueVolumeID +) +pdg_code = Attribute("pdg_code", g4.GetAttrPDGCode) +hit_unique_volume_id = Attribute("hit_unique_volume_id", g4.GetAttrHitUniqueVolumeID) + +# Position +position = Attribute("position", g4.GetAttrPosition) +post_position = Attribute("post_position", g4.GetAttrPostPosition) +pre_position = Attribute("pre_position", g4.GetAttrPrePosition) +pre_position_local = Attribute("pre_position_local", g4.GetAttrPrePositionLocal) +post_position_local = Attribute("post_position_local", g4.GetAttrPostPositionLocal) +event_position = Attribute("event_position", g4.GetAttrEventPosition) +track_vertex_position = Attribute( + "track_vertex_position", g4.GetAttrTrackVertexPosition +) + +# Direction +direction = Attribute("direction", g4.GetAttrDirection) +post_direction = Attribute("post_direction", g4.GetAttrPostDirection) +pre_direction = Attribute("pre_direction", g4.GetAttrPreDirection) +pre_direction_local = Attribute("pre_direction_local", g4.GetAttrPreDirectionLocal) +track_vertex_momentum_direction = Attribute( + "track_vertex_momentum_direction", g4.GetAttrTrackVertexMomentumDirection +) +event_direction = Attribute("event_direction", g4.GetAttrEventDirection) + +# Polarization +polarization = Attribute("polarization", g4.GetAttrPolarization) + +# Length +step_length = Attribute("step_length", g4.GetAttrStepLength) +track_length = Attribute("track_length", g4.GetAttrTrackLength) + +# Scatter information +unscattered_primary_flag = Attribute( + "unscattered_primary_flag", g4.GetAttrUnscatteredPrimaryFlag +) +################################################## + + +class FilterASTTransformer(ast.NodeTransformer): + def fn_name_from_attr(self, node): + f_src = inspect.getsource(eval(ast.unparse(node)).get) + tree = ast.parse(f_src) + f_name = tree.body[0].name + return f_name + + def visit_Name(self, node): + node_type = eval(f"type({node.id})") + if node_type is Attribute: + # Method 1 -- using local functions + # f_name = self.fn_name_from_attr(node) + # func = ast.Name(id=f"{f_name}", ctx=ast.Load()) + + # Method 2 -- using opengate_core functions + fn = eval(ast.unparse(node)).get + print(fn.__module__) + + func = ast.Attribute( + value=ast.Name(id="opengate_core", ctx=ast.Load()), + attr=fn.__name__, + ctx=ast.Load(), + ) + + args = [ast.Name(id="step", ctx=ast.Load())] + + return ast.Call(func=func, args=args, keywords=[]) + else: + value_types = (bool, int, float, str, type(None)) + value = eval(ast.unparse(node)) + if isinstance(value, value_types): + return ast.Constant(value=value) + + return node + + # resolve constant computations (e.g. 5 * MeV) + def visit_BinOp(self, node: ast.BinOp): + self.generic_visit(node) + if isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant): + value = eval(ast.unparse(node)) + return ast.Constant(value=value) + else: + return node + + def visit_Expr(self, node: ast.Expr): + self.generic_visit(node) + return node + + def visit_Compare(self, node: ast.Compare): + self.generic_visit(node) + return node diff --git a/opengate/tests/src/output/dose_dose.mhd b/opengate/tests/src/output/dose_dose.mhd new file mode 100644 index 000000000..fc10ae3ad --- /dev/null +++ b/opengate/tests/src/output/dose_dose.mhd @@ -0,0 +1,13 @@ +ObjectType = Image +NDims = 3 +BinaryData = True +BinaryDataByteOrderMSB = False +CompressedData = False +TransformMatrix = 1 0 0 0 1 0 0 0 1 +Offset = -150 -150 -154 +CenterOfRotation = 0 0 0 +AnatomicalOrientation = RAI +ElementSpacing = 10 10 2 +DimSize = 31 31 155 +ElementType = MET_DOUBLE +ElementDataFile = dose_dose.raw diff --git a/opengate/tests/src/output/dose_dose.raw b/opengate/tests/src/output/dose_dose.raw new file mode 100644 index 000000000..f0160fd46 Binary files /dev/null and b/opengate/tests/src/output/dose_dose.raw differ diff --git a/opengate/tests/src/output/dose_dose_uncertainty.mhd b/opengate/tests/src/output/dose_dose_uncertainty.mhd new file mode 100644 index 000000000..3620fef54 --- /dev/null +++ b/opengate/tests/src/output/dose_dose_uncertainty.mhd @@ -0,0 +1,13 @@ +ObjectType = Image +NDims = 3 +BinaryData = True +BinaryDataByteOrderMSB = False +CompressedData = False +TransformMatrix = 1 0 0 0 1 0 0 0 1 +Offset = -150 -150 -154 +CenterOfRotation = 0 0 0 +AnatomicalOrientation = RAI +ElementSpacing = 10 10 2 +DimSize = 31 31 155 +ElementType = MET_DOUBLE +ElementDataFile = dose_dose_uncertainty.raw diff --git a/opengate/tests/src/output/dose_dose_uncertainty.raw b/opengate/tests/src/output/dose_dose_uncertainty.raw new file mode 100644 index 000000000..991be688e Binary files /dev/null and b/opengate/tests/src/output/dose_dose_uncertainty.raw differ diff --git a/opengate/tests/src/output/dose_edep.mhd b/opengate/tests/src/output/dose_edep.mhd new file mode 100644 index 000000000..c6e3ccca5 --- /dev/null +++ b/opengate/tests/src/output/dose_edep.mhd @@ -0,0 +1,13 @@ +ObjectType = Image +NDims = 3 +BinaryData = True +BinaryDataByteOrderMSB = False +CompressedData = False +TransformMatrix = 1 0 0 0 1 0 0 0 1 +Offset = -150 -150 -154 +CenterOfRotation = 0 0 0 +AnatomicalOrientation = RAI +ElementSpacing = 10 10 2 +DimSize = 31 31 155 +ElementType = MET_DOUBLE +ElementDataFile = dose_edep.raw diff --git a/opengate/tests/src/output/dose_edep.raw b/opengate/tests/src/output/dose_edep.raw new file mode 100644 index 000000000..49b81f728 Binary files /dev/null and b/opengate/tests/src/output/dose_edep.raw differ diff --git a/opengate/tests/src/test023_filters_generic.py b/opengate/tests/src/test023_filters_generic.py new file mode 100644 index 000000000..d8c7f9724 --- /dev/null +++ b/opengate/tests/src/test023_filters_generic.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from box import Box +import click +import matplotlib.pyplot as plt +import numpy as np +import opengate as gate +import pathlib +import pyvista +import SimpleITK as sitk +from opengate.filters.ast import FilterASTTransformer +import ast + + +current_path = pathlib.Path(__file__).parent.resolve() +data_path = current_path / "data" +output_path = current_path / "output" +output_file = output_path / "dose.mhd" + +alpha_channel = -1 +colors = Box( + { + "invisible": [0, 0, 0, 0], + "red": [1, 0, 0, 1], + "green": [0, 1, 0, 1], + "blue": [0, 0, 1, 1], + "cyan": [0, 1, 1, 1], + "magenta": [1, 0, 1, 1], + "yellow": [1, 1, 0, 1], + "grey": [0.7, 0.7, 0.7, 1], + "white": [1, 1, 1, 1], + "pink": [1, 0.75, 0.79, 1], + "orange": [1, 0.5, 0, 1], + } +) + + +def simulation(n: int, visu=False): + # units + m = gate.g4_units.m + cm = gate.g4_units.cm + mm = gate.g4_units.mm + um = gate.g4_units.um + MeV = gate.g4_units.MeV + deg = gate.g4_units.deg + + tr = FilterASTTransformer() + e = ast.parse("particle_name == 'proton'") + e = tr.visit(e) + print(ast.dump(e)) + + # create the simulation + sim = gate.Simulation() + + sim.progress_bar = True + + # main user options + ui = sim.user_info + ui.g4_verbose = False + ui.g4_verbose_level = 1 + ui.visu = visu + ui.visu_type = "vrml_file_only" + ui.visu_filename = str(output_path / f"visu_{n}.wrl") + ui.random_seed = "auto" + ui.number_of_threads = 4 + + # change world size + world = sim.world + world.size = [5 * m, 5 * m, 5 * m] + world.color[alpha_channel] = 0 + + # water box + waterbox = sim.add_volume("Box", "waterbox") + waterbox.size = [31 * cm, 31 * cm, 31 * cm] + waterbox.translation = [0 * mm, 0 * mm, 0 * mm] + waterbox.material = "G4_WATER" + waterbox.set_max_step_size(0.1 * mm) + waterbox.color = colors.cyan + + # physics + sim.physics_manager.physics_list_name = "G4EmStandardPhysics_option4" + sim.physics_manager.set_production_cut("world", "gamma", 10 * m) + sim.physics_manager.set_production_cut("world", "electron", 10 * m) + sim.physics_manager.set_production_cut("world", "positron", 10 * m) + + if visu: + sim.physics_manager.set_production_cut(waterbox.name, "gamma", 1 * mm) + sim.physics_manager.set_production_cut(waterbox.name, "electron", 1 * mm) + sim.physics_manager.set_production_cut(waterbox.name, "positron", 1 * mm) + else: + sim.physics_manager.set_production_cut(waterbox.name, "gamma", 1 * um) + sim.physics_manager.set_production_cut(waterbox.name, "electron", 1 * um) + sim.physics_manager.set_production_cut(waterbox.name, "positron", 1 * um) + + sim.physics_manager.set_user_limits_particles(["gamma", "electron"]) + + # source + source = sim.add_source("GenericSource", "beam") + source.particle = "e-" + source.energy.mono = 20 * MeV + source.position.type = "point" + source.position.translation = [0 * mm, 0 * mm, 1 * m + 15.5 * cm] + source.direction.type = "iso" + source.direction.theta = [0 * deg, 3 * deg] # ZOX plane + source.direction.phi = [0 * deg, 360 * deg] # YOX plane + source.n = n / ui.number_of_threads + + # dose actor + dose = sim.add_actor("DoseActor", "dose") + dose.attached_to = waterbox + dose.output_filename = output_file + dose.size = [31, 31, 155] + dose.spacing = [1 * cm, 1 * cm, 2 * mm] + dose.hit_type = "random" + dose.dose.active = True + dose.dose_uncertainty.active = True + + fp = sim.add_filter("ParticleFilter", "fp") + fp.particle = "gamma" + + # dose.filters.append(fp) + + # dose.filter = "5 == 0" + dose.filter = "particle_name == 'gamma' and 5 == 5 or 2 * pre_kinetic_energy < 20 * MeV and dbgp((pre_kinetic_energy - post_kinetic_energy) / step_length)" + # dose.filter = "particle_name == 'gamma'" + + # add stat actor + stats = sim.add_actor("SimulationStatisticsActor", "stats") + stats.track_types_flag = True + + # start simulation + sim.run() + + # print results at the end + print(stats) + + +def analysis(): + img = sitk.ReadImage(str(output_file).replace(".mhd", "_dose.mhd")) + data = np.array(sitk.GetArrayFromImage(img)) + profile = [np.sum(plan) for plan in data] + profile = profile[::-1] # reverse order + profile = profile[: len(profile) // 2] + + # Dose profile figure + fig, ax = plt.subplots(figsize=(5.5, 3.8), dpi=300) + plt.title("Depth dose profile") + ax.set_xlabel("Depth (voxel)") + ax.set_ylabel("Dose (Gy)") + + ax.plot(profile) + + fig.savefig("depth_dose_profile.png") + plt.show() + plt.close(fig) + + +def visualisation(n: int): + pl = pyvista.Plotter() + pl.import_vrml(str(output_path / f"visu_{n}.wrl")) + pl.add_axes(line_width=5, color="white") + pl.background_color = "black" + for actor in pl.renderer.GetActors(): + actor.GetProperty().SetOpacity(0.7) + pl.show() + + +@click.command() +@click.option( + "-s", + "--sim", + "--simulation", + "enable_sim", + is_flag=True, + default=False, + help="enable simulation", +) +@click.option( + "-a", + "--analysis", + "enable_analysis", + is_flag=True, + default=False, + help="enable analysis", +) +@click.option( + "-V", + "--visu", + "--visualisation", + "enable_visu", + is_flag=True, + default=False, + help="enable visualisation", +) +@click.option( + "-n", "--primaries", "n", type=str, default="1e2", help="number of primaries" +) +def main(enable_sim: bool, enable_analysis: bool, enable_visu: bool, n: str): + n = int(float(n)) # handle scientific notation + + if enable_sim: + simulation(n, visu=enable_visu) + + if enable_analysis: + analysis() + + if enable_visu: + visualisation(n) + + +if __name__ == "__main__": + main()