diff --git a/api/controllers/linux/install/controller_test.go b/api/controllers/linux/install/controller_test.go index e6de4ee6c0..eac2163a3f 100644 --- a/api/controllers/linux/install/controller_test.go +++ b/api/controllers/linux/install/controller_test.go @@ -467,14 +467,14 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostPreflightsSucceeded, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil), - ) + // BEFORE: Had to mock all 4 methods even though test only validates the flow + // NOW: Only override what matters for this specific test case + // Default stubs handle: ClearHostPreflightResults (returns nil) + // Override PrepareHostPreflights to return the expected spec + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + // Default stubs handle: RunHostPreflights (returns nil) + // Override GetHostPreflightOutput to return successful result + pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil) }, expectedErr: false, }, @@ -483,14 +483,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsExecutionFailed, expectedState: states.StateHostPreflightsSucceeded, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil) }, expectedErr: false, }, @@ -499,14 +495,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsFailed, expectedState: states.StateHostPreflightsSucceeded, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil) }, expectedErr: false, }, @@ -515,14 +507,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsFailedBypassed, expectedState: states.StateHostPreflightsSucceeded, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(successfulPreflightOutput, nil) }, expectedErr: false, }, @@ -531,16 +519,12 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsFailedBypassed, expectedState: states.StateHostPreflightsFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil), - st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil), - mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil) + st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil) + mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil) }, expectedErr: false, }, @@ -549,15 +533,11 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsFailedBypassed, expectedState: states.StateHostPreflightsFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil), - st.LinuxPreflightMockStore.On("GetOutput").Return(nil, assert.AnError), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil) + st.LinuxPreflightMockStore.On("GetOutput").Return(nil, assert.AnError) }, expectedErr: false, }, @@ -566,16 +546,12 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsExecutionFailed, expectedState: states.StateHostPreflightsFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil), - st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil), - mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil) + st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil) + mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil) }, expectedErr: false, }, @@ -584,16 +560,12 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsFailed, expectedState: states.StateHostPreflightsFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil), - st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil), - mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil) + st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil) + mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil) }, expectedErr: false, }, @@ -602,16 +574,12 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostPreflightsFailedBypassed, expectedState: states.StateHostPreflightsFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil), - st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil), - mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(failedPreflightOutput, nil) + st.LinuxPreflightMockStore.On("GetOutput").Return(failedPreflightOutput, nil) + mr.On("ReportHostPreflightsFailed", mock.Anything, failedPreflightOutput).Return(nil) }, expectedErr: false, }, @@ -620,14 +588,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostPreflightsExecutionFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(nil, assert.AnError), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(nil, assert.AnError) }, expectedErr: false, }, @@ -636,14 +600,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostPreflightsSucceeded, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(nil, nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(nil, nil) }, expectedErr: false, }, @@ -652,14 +612,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostPreflightsSucceeded, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(nil), - pm.On("GetHostPreflightOutput", mock.Anything).Return(warnPreflightOutput, nil), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil), RunHostPreflights (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("GetHostPreflightOutput", mock.Anything).Return(warnPreflightOutput, nil) }, expectedErr: false, }, @@ -668,10 +624,10 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostConfigured, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(nil, errors.New("prepare error")), - ) + // BEFORE: Had to explicitly mock ClearHostPreflightResults even though we don't care + // NOW: Only mock the failure case we're testing + // Default stubs handle: ClearHostPreflightResults (returns nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(nil, errors.New("prepare error")) }, expectedErr: true, }, @@ -680,13 +636,12 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostPreflightsExecutionFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Return(errors.New("run preflights error")), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { + return expectedHPF == opts.HostPreflightSpec + })).Return(errors.New("run preflights error")) }, expectedErr: false, }, @@ -695,13 +650,12 @@ func TestRunHostPreflights(t *testing.T) { currentState: states.StateHostConfigured, expectedState: states.StateHostPreflightsExecutionFailed, setupMocks: func(pm *preflight.MockHostPreflightManager, rc runtimeconfig.RuntimeConfig, mr *metrics.MockReporter, st *store.MockStore) { - mock.InOrder( - pm.On("ClearHostPreflightResults", mock.Anything).Return(nil), - pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil), - pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { - return expectedHPF == opts.HostPreflightSpec - })).Panic("this is a panic"), - ) + // Only mock what differs from defaults + // Defaults handle: ClearHostPreflightResults (nil) + pm.On("PrepareHostPreflights", t.Context(), rc, mock.Anything).Return(expectedHPF, nil) + pm.On("RunHostPreflights", mock.Anything, rc, mock.MatchedBy(func(opts preflight.RunHostPreflightOptions) bool { + return expectedHPF == opts.HostPreflightSpec + })).Panic("this is a panic") }, expectedErr: false, }, @@ -728,6 +682,8 @@ func TestRunHostPreflights(t *testing.T) { sm := NewStateMachine(WithCurrentState(tt.currentState)) + // Use mock with built-in default stubs via maybeRegisterCall + // Tests only need to set up On() expectations for methods they care about mockPreflightManager := &preflight.MockHostPreflightManager{} mockMetricsReporter := &metrics.MockReporter{} mockStore := &store.MockStore{} @@ -789,6 +745,7 @@ func TestGetHostPreflightStatus(t *testing.T) { { name: "successful get status", setupMock: func(m *preflight.MockHostPreflightManager) { + // Only override the method we care about - defaults handle the rest status := types.Status{ State: types.StateFailed, } @@ -802,6 +759,7 @@ func TestGetHostPreflightStatus(t *testing.T) { { name: "get status error", setupMock: func(m *preflight.MockHostPreflightManager) { + // Only override the method we care about - defaults handle the rest m.On("GetHostPreflightStatus", t.Context()).Return(nil, errors.New("get status error")) }, expectedErr: true, @@ -811,6 +769,7 @@ func TestGetHostPreflightStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Simple test - only mock the one method we call, defaults handle the rest mockManager := &preflight.MockHostPreflightManager{} tt.setupMock(mockManager) @@ -864,6 +823,7 @@ func TestGetHostPreflightOutput(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Simple test - only mock the one method we call, defaults handle the rest mockManager := &preflight.MockHostPreflightManager{} tt.setupMock(mockManager) @@ -917,6 +877,7 @@ func TestGetHostPreflightTitles(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Simple test - only mock the one method we call, defaults handle the rest mockManager := &preflight.MockHostPreflightManager{} tt.setupMock(mockManager) diff --git a/api/internal/managers/linux/preflight/manager_mock.go b/api/internal/managers/linux/preflight/manager_mock.go index 0094529926..3845b92b99 100644 --- a/api/internal/managers/linux/preflight/manager_mock.go +++ b/api/internal/managers/linux/preflight/manager_mock.go @@ -4,62 +4,87 @@ import ( "context" "github.com/replicatedhq/embedded-cluster/api/types" + ecmock "github.com/replicatedhq/embedded-cluster/pkg-new/mock" "github.com/replicatedhq/embedded-cluster/pkg/runtimeconfig" troubleshootv1beta2 "github.com/replicatedhq/troubleshoot/pkg/apis/troubleshoot/v1beta2" - "github.com/stretchr/testify/mock" ) var _ HostPreflightManager = (*MockHostPreflightManager)(nil) -// MockHostPreflightManager is a mock implementation of the HostPreflightManager interface +// MockHostPreflightManager is a mock implementation of the HostPreflightManager interface. +// It embeds ecmock.Mock which provides MaybeRegisterCall for automatic default stub behavior. type MockHostPreflightManager struct { - mock.Mock + ecmock.Mock } // PrepareHostPreflights mocks the PrepareHostPreflights method func (m *MockHostPreflightManager) PrepareHostPreflights(ctx context.Context, rc runtimeconfig.RuntimeConfig, opts PrepareHostPreflightOptions) (*troubleshootv1beta2.HostPreflightSpec, error) { - args := m.Called(ctx, rc, opts) - if args.Get(0) == nil { - return nil, args.Error(1) + if registered, args := m.MaybeRegisterCall(ctx, rc, opts); registered { + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*troubleshootv1beta2.HostPreflightSpec), args.Error(1) } - return args.Get(0).(*troubleshootv1beta2.HostPreflightSpec), args.Error(1) + + // Default stub: return empty spec, no error + return &troubleshootv1beta2.HostPreflightSpec{}, nil } // RunHostPreflights mocks the RunHostPreflights method func (m *MockHostPreflightManager) RunHostPreflights(ctx context.Context, rc runtimeconfig.RuntimeConfig, opts RunHostPreflightOptions) error { - args := m.Called(ctx, rc, opts) - return args.Error(0) + if registered, args := m.MaybeRegisterCall(ctx, rc, opts); registered { + return args.Error(0) + } + + // Default stub: succeed + return nil } // GetHostPreflightStatus mocks the GetHostPreflightStatus method func (m *MockHostPreflightManager) GetHostPreflightStatus(ctx context.Context) (types.Status, error) { - args := m.Called(ctx) - if args.Get(0) == nil { - return types.Status{}, args.Error(1) + if registered, args := m.MaybeRegisterCall(ctx); registered { + if args.Get(0) == nil { + return types.Status{}, args.Error(1) + } + return args.Get(0).(types.Status), args.Error(1) } - return args.Get(0).(types.Status), args.Error(1) + + // Default stub: return running state, no error + return types.Status{State: types.StateRunning}, nil } // GetHostPreflightOutput mocks the GetHostPreflightOutput method func (m *MockHostPreflightManager) GetHostPreflightOutput(ctx context.Context) (*types.PreflightsOutput, error) { - args := m.Called(ctx) - if args.Get(0) == nil { - return nil, args.Error(1) + if registered, args := m.MaybeRegisterCall(ctx); registered { + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*types.PreflightsOutput), args.Error(1) } - return args.Get(0).(*types.PreflightsOutput), args.Error(1) + + // Default stub: return empty successful output, no error + return &types.PreflightsOutput{}, nil } // GetHostPreflightTitles mocks the GetHostPreflightTitles method func (m *MockHostPreflightManager) GetHostPreflightTitles(ctx context.Context) ([]string, error) { - args := m.Called(ctx) - if args.Get(0) == nil { - return nil, args.Error(1) + if registered, args := m.MaybeRegisterCall(ctx); registered { + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) } - return args.Get(0).([]string), args.Error(1) + + // Default stub: return empty list, no error + return []string{}, nil } // ClearHostPreflightResults mocks the ClearHostPreflightResults method func (m *MockHostPreflightManager) ClearHostPreflightResults(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) + if registered, args := m.MaybeRegisterCall(ctx); registered { + return args.Error(0) + } + + // Default stub: succeed + return nil } diff --git a/pkg-new/mock/mock.go b/pkg-new/mock/mock.go new file mode 100644 index 0000000000..610eb9d3a9 --- /dev/null +++ b/pkg-new/mock/mock.go @@ -0,0 +1,75 @@ +package mock + +import ( + "runtime" + "strings" + + "github.com/stretchr/testify/mock" +) + +// Mock embeds testify/mock.Mock and provides helper methods for creating mocks +// with default stub behavior. +// +// When embedding this in your mock structs, mock methods can use MaybeRegisterCall +// to automatically provide sensible defaults when tests don't set explicit expectations. +// +// Example usage: +// +// type MockService struct { +// mock.Mock // Replace with: ecmock.Mock +// } +// +// func (m *MockService) GetData(id string) (string, error) { +// if registered, args := m.MaybeRegisterCall(id); registered { +// return args.String(0), args.Error(1) +// } +// // Return default stub when no expectation is set +// return "", nil +// } +type Mock struct { + mock.Mock +} + +// MaybeRegisterCall checks if there are explicit expectations registered for a method. +// It automatically detects the calling method name using runtime.Caller. +// +// If expectations exist, it calls m.MethodCalled() and returns (true, arguments). +// If no expectations exist, it returns (false, nil) and the caller should use default stubs. +// +// This allows mocks to have default "happy path" behavior without requiring +// every test to set up expectations for methods they don't care about. +// +// Usage in mock methods: +// +// func (m *MockService) GetData(id string) (string, error) { +// if registered, args := m.MaybeRegisterCall(id); registered { +// return args.String(0), args.Error(1) +// } +// // Return default stub values +// return "default-data", nil +// } +func (m *Mock) MaybeRegisterCall(args ...interface{}) (bool, mock.Arguments) { + // Get the method name of the caller + pc, _, _, ok := runtime.Caller(1) + if !ok { + panic("Couldn't get the caller information") + } + fullMethodName := runtime.FuncForPC(pc).Name() + + // Extract just the method name from the full path + // fullMethodName is like: "github.com/user/pkg.(*Type).MethodName" + // We need to extract "MethodName" + parts := strings.Split(fullMethodName, ".") + methodName := parts[len(parts)-1] + + // Check if any expectations are registered for this method + for _, call := range m.ExpectedCalls { + if call.Method == methodName { + // Found an explicit expectation - use normal mock behavior + return true, m.MethodCalled(methodName, args...) + } + } + + // No expectations found - caller should use default stub + return false, nil +}