diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 23bd4ba..25bf84b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,10 +1,9 @@ - name: CI on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] jobs: test-example: @@ -12,11 +11,11 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - zig-version: ["0.14.1"] + zig-version: ["0.15.1"] steps: - uses: actions/checkout@v3 - + - name: Install Zig uses: goto-bus-stop/setup-zig@v2 with: @@ -24,6 +23,6 @@ jobs: - name: Check Zig Version run: zig version - + - name: Run tests run: zig build test diff --git a/README.md b/README.md index e334dca..58b8a86 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,32 @@ # Zig Interfaces & Validation -A compile-time interface checker for Zig that enables interface-based design -with comprehensive type checking and detailed error reporting. +A comprehensive interface system for Zig supporting both **compile-time +validation** and **runtime polymorphism** through VTable generation. ## Features -This library provides a way to define and verify interfaces in Zig at compile -time. It supports: +This library provides two complementary approaches to interface-based design in +Zig: -- Type-safe interface definitions with detailed error reporting +**VTable-Based Runtime Polymorphism:** + +- **Automatic VTable wrapper generation** +- Automatic VTable type generation from interface definitions +- Runtime polymorphism with function pointer dispatch +- Return interface types from functions and store in fields + +**Compile-Time Interface Validation:** + +- Zero-overhead generic functions with compile-time type checking +- Detailed error reporting for interface mismatches - Interface embedding (composition) - Complex type validation including structs, enums, arrays, and slices -- Comprehensive compile-time error messages with helpful hints - Flexible error union compatibility with `anyerror` ## Install -Add or update this library as a dependency in your zig project run the following command: +Add or update this library as a dependency in your zig project run the following +command: ```sh zig fetch --save git+https://github.com/nilslice/zig-interface @@ -48,27 +58,39 @@ In the end you can import the `interface` module. For example: const Interface = @import("interface").Interface; const Repository = Interface(.{ - .create = fn(anytype, User) anyerror!u32, - .findById = fn(anytype, u32) anyerror!?User, - .update = fn(anytype, User) anyerror!void, - .delete = fn(anytype, u32) anyerror!void, + .create = fn(User) anyerror!u32, + .findById = fn(u32) anyerror!?User, + .update = fn(User) anyerror!void, + .delete = fn(u32) anyerror!void, }, null); ``` ## Usage -1. Define an interface with required method signatures: +### VTable-Based Runtime Polymorphism + +The primary use case for this library is creating type-erased interface objects +that enable runtime polymorphism. This is ideal for storing different +implementations in collections, returning interface types from functions, or +building plugin systems. + +**1. Define an interface with required method signatures:** ```zig const Repository = Interface(.{ - .create = fn(anytype, User) anyerror!u32, - .findById = fn(anytype, u32) anyerror!?User, - .update = fn(anytype, User) anyerror!void, - .delete = fn(anytype, u32) anyerror!void, + .create = fn(User) anyerror!u32, + .findById = fn(u32) anyerror!?User, + .update = fn(User) anyerror!void, + .delete = fn(u32) anyerror!void, }, null); ``` -2. Implement the interface methods in your type: +> Note: `Interface()` generates a type whose function set declared implicitly +> take an `*anyopaque` self-reference. This saves you from needing to include it +> in the declaration. However, `anyerror` must be included for any fallible +> function, but can be omitted if your function cannot return an error. + +**2. Implement the interface methods in your type:** ```zig const InMemoryRepository = struct { @@ -84,43 +106,108 @@ const InMemoryRepository = struct { return new_user.id; } - // ... other Repository methods + pub fn findById(self: InMemoryRepository, id: u32) !?User { + return self.users.get(id); + } + + pub fn update(self: *InMemoryRepository, user: User) !void { + if (!self.users.contains(user.id)) return error.UserNotFound; + try self.users.put(user.id, user); + } + + pub fn delete(self: *InMemoryRepository, id: u32) !void { + if (!self.users.remove(id)) return error.UserNotFound; + } }; ``` -3. Verify the implementation at compile time: +**3. Use the interface for runtime polymorphism:** ```zig -// In functions that accept interface implementations: +// Create different repository implementations +var in_memory_repo = InMemoryRepository.init(allocator); +var sql_repo = SqlRepository.init(allocator, db_connection); + +// Convert to interface objects +const repo1 = Repository.from(&in_memory_repo); +const repo2 = Repository.from(&sql_repo); + +// Store in heterogeneous collection +var repositories = [_]Repository{ repo1, repo2 }; + +// Use through the interface - runtime polymorphism! +for (repositories) |repo| { + const user = User{ .id = 0, .name = "Alice", .email = "alice@example.com" }; + const id = try repo.vtable.create(repo.ptr, user); + const found = try repo.vtable.findById(repo.ptr, id); +} + +// Return interface types from functions +fn getRepository(use_memory: bool, allocator: Allocator) Repository { + if (use_memory) { + var repo = InMemoryRepository.init(allocator); + return Repository.from(&repo); + } else { + var repo = SqlRepository.init(allocator); + return Repository.from(&repo); + } +} +``` + +### Compile-Time Validation (Alternative Approach) + +For generic functions where you know the concrete type at compile time, you can +use the interface for validation without the VTable overhead: + +```zig +// Generic function that accepts any Repository implementation fn createUser(repo: anytype, name: []const u8, email: []const u8) !User { - comptime Repository.satisfiedBy(@TypeOf(repo)); - // ... rest of implementation + // Validate at compile time that repo implements IRepository + comptime Repository.validation.satisfiedBy(@TypeOf(repo.*)); + + const user = User{ .id = 0, .name = name, .email = email }; + const id = try repo.create(user); + return User{ .id = id, .name = name, .email = email }; } -// Or verify directly: -comptime Repository.satisfiedBy(InMemoryRepository); +// Works with any concrete implementation - no VTable needed +var in_memory = InMemoryRepository.init(allocator); +const user = try createUser(&in_memory, "Alice", "alice@example.com"); ``` ## Interface Embedding -Interfaces can embed other interfaces to combine their requirements: +Interfaces can embed other interfaces to combine their requirements. The +generated VTable will include all methods from embedded interfaces: ```zig const Logger = Interface(.{ - .log = fn(anytype, []const u8) void, - .getLogLevel = fn(anytype) u8, + .log = fn([]const u8) void, + .getLogLevel = fn() u8, }, null); const Metrics = Interface(.{ - .increment = fn(anytype, []const u8) void, - .getValue = fn(anytype, []const u8) u64, + .increment = fn([]const u8) void, + .getValue = fn([]const u8) u64, }, .{ Logger }); // Embeds Logger interface -// Now implements both Metrics and Logger methods -const MonitoredRepository = Interface(.{ - .create = fn(anytype, User) anyerror!u32, - .findById = fn(anytype, u32) anyerror!?User, -}, .{ Metrics }); +// Implementation must provide all methods +const MyMetrics = struct { + log_level: u8, + counters: std.StringHashMap(u64), + + // Logger methods + pub fn log(self: MyMetrics, msg: []const u8) void { ... } + pub fn getLogLevel(self: MyMetrics) u8 { return self.log_level; } + + // Metrics methods + pub fn increment(self: *MyMetrics, name: []const u8) void { ... } + pub fn getValue(self: MyMetrics, name: []const u8) u64 { ... } +}; + +// Use it with auto-generated wrappers: +var my_metrics = MyMetrics{ ... }; +const metrics = Metrics.from(&my_metrics); ``` > Note: you can embed arbitrarily many interfaces! @@ -148,12 +235,12 @@ const BadImpl = struct { ## Complex Types -The interface checker supports complex types including: +The interface checker supports complex types including structs, enums, arrays, +and optionals: ```zig -const ComplexTypes = Interface(.{ +const Processor = Interface(.{ .process = fn( - anytype, struct { config: Config, points: []const DataPoint }, enum { ready, processing, error }, []const struct { @@ -163,4 +250,52 @@ const ComplexTypes = Interface(.{ } ) anyerror!?ProcessingResult, }, null); +... +``` + +## Choosing Between VTable and Compile-Time Approaches + +Both approaches work from the same interface definition and can be used +together: + +| Feature | VTable Runtime Polymorphism | Compile-Time Validation | +| ------------------- | --------------------------------------------------------------- | ---------------------------------- | +| **Use Case** | Heterogeneous collections, plugin systems, returning interfaces | Generic functions, static dispatch | +| **Performance** | Function pointer indirection | Zero overhead (monomorphization) | +| **Binary Size** | Smaller (shared dispatch code) | Larger (per-type instantiation) | +| **Flexibility** | Store in arrays, return from functions | Known types at compile time | +| **Type Visibility** | Type-erased (`*anyopaque`) | Concrete type always known | +| **Method Calls** | `interface.vtable.method(interface.ptr, args)` | Direct: `instance.method(args)` | +| **When to Use** | Need runtime flexibility | Need maximum performance | + +**Example using both:** + +```zig +// Define once +const Repository = Interface(.{ + .save = fn(Data) anyerror!void, +}, null); + +// Use compile-time validation for hot paths +fn processBatch(repo: anytype, items: []const Data) !void { + comptime Repository.validation.satisfiedBy(@TypeOf(repo.*)); + for (items) |item| { + try repo.save(item); // Direct call, can be inlined + } +} + +// Use VTable for plugin registry +const PluginRegistry = struct { + repositories: []Repository, + + fn addPlugin(self: *PluginRegistry, repo: Repository) void { + self.repositories = self.repositories ++ &[_]Repository{repo}; + } + + fn saveToAll(self: PluginRegistry, data: Data) !void { + for (self.repositories) |repo| { + try repo.vtable.save(repo.ptr, data); + } + } +}; ``` diff --git a/build.zig b/build.zig index 2d0fa3b..b17f74d 100644 --- a/build.zig +++ b/build.zig @@ -15,41 +15,117 @@ pub fn build(b: *std.Build) void { // set a preferred release mode, allowing the user to decide how to optimize. const optimize = b.standardOptimizeOption(.{}); - const interface_lib = b.addModule("interface", .{ + const interface_module = b.addModule("interface", .{ .root_source_file = b.path("src/interface.zig"), }); - // Creates a step for unit testing. This only builds the test executable - // but does not run it. - const simple_unit_tests = b.addTest(.{ + // Create test modules + const test_module = b.createModule(.{ + .root_source_file = b.path("src/interface.zig"), + .target = target, + .optimize = optimize, + }); + + // Creates a step for unit testing. + const lib_unit_tests = b.addTest(.{ + .root_module = test_module, + }); + + const run_lib_unit_tests = b.addRunArtifact(lib_unit_tests); + + // Simple test + const simple_test_module = b.createModule(.{ .root_source_file = b.path("test/simple.zig"), .target = target, .optimize = optimize, }); - simple_unit_tests.root_module.addImport("interface", interface_lib); - const run_simple_unit_tests = b.addRunArtifact(simple_unit_tests); + simple_test_module.addImport("interface", interface_module); + + const simple_tests = b.addTest(.{ + .root_module = simple_test_module, + }); - const complex_unit_tests = b.addTest(.{ + const run_simple_tests = b.addRunArtifact(simple_tests); + + // Complex test + const complex_test_module = b.createModule(.{ .root_source_file = b.path("test/complex.zig"), .target = target, .optimize = optimize, }); - complex_unit_tests.root_module.addImport("interface", interface_lib); - const run_complex_unit_tests = b.addRunArtifact(complex_unit_tests); + complex_test_module.addImport("interface", interface_module); + + const complex_tests = b.addTest(.{ + .root_module = complex_test_module, + }); - const embedded_unit_tests = b.addTest(.{ + const run_complex_tests = b.addRunArtifact(complex_tests); + + // Embedded test + const embedded_test_module = b.createModule(.{ .root_source_file = b.path("test/embedded.zig"), .target = target, .optimize = optimize, }); - embedded_unit_tests.root_module.addImport("interface", interface_lib); - const run_embedded_unit_tests = b.addRunArtifact(embedded_unit_tests); + embedded_test_module.addImport("interface", interface_module); + + const embedded_tests = b.addTest(.{ + .root_module = embedded_test_module, + }); + + const run_embedded_tests = b.addRunArtifact(embedded_tests); + + // Vtable test + const vtable_test_module = b.createModule(.{ + .root_source_file = b.path("test/vtable.zig"), + .target = target, + .optimize = optimize, + }); + vtable_test_module.addImport("interface", interface_module); + + const vtable_tests = b.addTest(.{ + .root_module = vtable_test_module, + }); + + const run_vtable_tests = b.addRunArtifact(vtable_tests); + + // Collections test + const collections_test_module = b.createModule(.{ + .root_source_file = b.path("test/collections.zig"), + .target = target, + .optimize = optimize, + }); + collections_test_module.addImport("interface", interface_module); + + const collections_tests = b.addTest(.{ + .root_module = collections_test_module, + }); + + const run_collections_tests = b.addRunArtifact(collections_tests); + + // Inference test + const inference_test_module = b.createModule(.{ + .root_source_file = b.path("test/inference.zig"), + .target = target, + .optimize = optimize, + }); + inference_test_module.addImport("interface", interface_module); + + const inference_tests = b.addTest(.{ + .root_module = inference_test_module, + }); + + const run_inference_tests = b.addRunArtifact(inference_tests); // Similar to creating the run step earlier, this exposes a `test` step to // the `zig build --help` menu, providing a way for the user to request // running the unit tests. const test_step = b.step("test", "Run unit tests"); - test_step.dependOn(&run_simple_unit_tests.step); - test_step.dependOn(&run_complex_unit_tests.step); - test_step.dependOn(&run_embedded_unit_tests.step); + test_step.dependOn(&run_lib_unit_tests.step); + test_step.dependOn(&run_simple_tests.step); + test_step.dependOn(&run_complex_tests.step); + test_step.dependOn(&run_embedded_tests.step); + test_step.dependOn(&run_vtable_tests.step); + test_step.dependOn(&run_collections_tests.step); + test_step.dependOn(&run_inference_tests.step); } diff --git a/src/interface.zig b/src/interface.zig index 84a9d63..0ee0ba7 100644 --- a/src/interface.zig +++ b/src/interface.zig @@ -1,4 +1,175 @@ const std = @import("std"); +const builtin = @import("builtin"); + +pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { + const embedded_interfaces = switch (@typeInfo(@TypeOf(embedded))) { + .null => embedded, + .@"struct" => |s| if (s.is_tuple) embedded else .{embedded}, + else => .{embedded}, + }; + + const has_embeds = @TypeOf(embedded_interfaces) != @TypeOf(null); + + // Generate VTable type with function pointers + const VTableType = generateVTableType(methods, embedded_interfaces, has_embeds); + + // Create the validation namespace + const ValidationNamespace = CreateValidationNamespace(methods, embedded_interfaces, has_embeds); + + // Return the VTable-based interface type directly + return struct { + ptr: *anyopaque, + vtable: *const VTableType, + + pub const VTable = VTableType; + pub const validation = ValidationNamespace; + + /// Creates an interface wrapper from an implementation pointer and vtable. + pub fn init(impl: anytype, vtable_ptr: *const VTableType) @This() { + const ImplPtr = @TypeOf(impl); + const impl_type_info = @typeInfo(ImplPtr); + + // Verify it's a pointer + if (impl_type_info != .pointer) { + @compileError("init() requires a pointer to an implementation, got: " ++ @typeName(ImplPtr)); + } + + const ImplType = impl_type_info.pointer.child; + + // Validate that the type satisfies the interface at compile time + comptime validation.satisfiedBy(ImplType); + + return .{ + .ptr = impl, + .vtable = vtable_ptr, + }; + } + + /// Automatically generates VTable wrappers and creates an interface wrapper. + pub fn from(impl: anytype) @This() { + const ImplPtr = @TypeOf(impl); + const impl_type_info = @typeInfo(ImplPtr); + + // Verify it's a pointer + if (impl_type_info != .pointer) { + @compileError("from() requires a pointer to an implementation, got: " ++ @typeName(ImplPtr)); + } + + const ImplType = impl_type_info.pointer.child; + + // Validate that the type satisfies the interface at compile time + comptime validation.satisfiedBy(ImplType); + + // Generate a unique wrapper struct with static VTable for this ImplType + const gen = struct { + fn generateWrapperForField(comptime T: type, comptime vtable_field: std.builtin.Type.StructField) *const anyopaque { + // Extract function signature from vtable field + const fn_ptr_info = @typeInfo(vtable_field.type); + const fn_info = @typeInfo(fn_ptr_info.pointer.child).@"fn"; + const method_name = vtable_field.name; + + // Check if the implementation method expects *T or T + const impl_method_info = @typeInfo(@TypeOf(@field(T, method_name))); + const impl_fn_info = impl_method_info.@"fn"; + const first_param_info = @typeInfo(impl_fn_info.params[0].type.?); + const expects_pointer = first_param_info == .pointer; + + // Generate wrapper matching the exact signature + const param_count = fn_info.params.len; + if (param_count < 1 or param_count > 5) { + @compileError("Method '" ++ method_name ++ "' has too many parameters. Only 1-5 parameters (including self pointer) are supported."); + } + + // Create wrapper with exact parameter types from VTable signature + if (expects_pointer) { + return switch (param_count) { + 1 => &struct { + fn wrapper(ptr: *anyopaque) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self); + } + }.wrapper, + 2 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self, p1); + } + }.wrapper, + 3 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?, p2: fn_info.params[2].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self, p1, p2); + } + }.wrapper, + 4 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?, p2: fn_info.params[2].type.?, p3: fn_info.params[3].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self, p1, p2, p3); + } + }.wrapper, + 5 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?, p2: fn_info.params[2].type.?, p3: fn_info.params[3].type.?, p4: fn_info.params[4].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self, p1, p2, p3, p4); + } + }.wrapper, + else => unreachable, + }; + } else { + return switch (param_count) { + 1 => &struct { + fn wrapper(ptr: *anyopaque) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self.*); + } + }.wrapper, + 2 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self.*, p1); + } + }.wrapper, + 3 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?, p2: fn_info.params[2].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self.*, p1, p2); + } + }.wrapper, + 4 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?, p2: fn_info.params[2].type.?, p3: fn_info.params[3].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self.*, p1, p2, p3); + } + }.wrapper, + 5 => &struct { + fn wrapper(ptr: *anyopaque, p1: fn_info.params[1].type.?, p2: fn_info.params[2].type.?, p3: fn_info.params[3].type.?, p4: fn_info.params[4].type.?) callconv(fn_info.calling_convention) fn_info.return_type.? { + const self: *T = @ptrCast(@alignCast(ptr)); + return @field(T, method_name)(self.*, p1, p2, p3, p4); + } + }.wrapper, + else => unreachable, + }; + } + } + + const vtable: VTableType = blk: { + var result: VTableType = undefined; + // Iterate over all VTable fields (includes embedded interface methods) + for (std.meta.fields(VTableType)) |vtable_field| { + const wrapper_ptr = generateWrapperForField(ImplType, vtable_field); + @field(result, vtable_field.name) = @ptrCast(@alignCast(wrapper_ptr)); + } + break :blk result; + }; + }; + + return .{ + .ptr = impl, + .vtable = &gen.vtable, + }; + } + }; +} /// Compares two types structurally to determine if they're compatible fn isTypeCompatible(comptime T1: type, comptime T2: type) bool { @@ -127,67 +298,109 @@ fn formatTypeMismatch( return result; } -/// Creates a verifiable interface type that can be used to define method requirements -/// for other types. Interfaces can embed other interfaces, combining their requirements. -/// -/// The interface consists of method signatures that implementing types must match exactly. -/// Method signatures must use `anytype` for the self parameter to allow any implementing type. -/// -/// Supports: -/// - Complex types (structs, enums, arrays, slices) -/// - Error unions with specific or `anyerror` -/// - Optional types and comptime checking -/// - Interface embedding (combining multiple interfaces) -/// - Detailed error reporting for mismatched implementations -/// -/// Params: -/// methods: A struct of function signatures that define the interface -/// embedded: A tuple of other interfaces to embed, or null for no embedding -/// -/// Example: -/// ``` -/// const Writer = Interface(.{ -/// .writeAll = fn(anytype, []const u8) anyerror!void, -/// }, null); -/// -/// const Logger = Interface(.{ -/// .log = fn(anytype, []const u8) void, -/// }, .{ Writer }); // Embeds Writer interface -/// -/// // Usage in functions: -/// fn write(w: anytype, data: []const u8) !void { -/// comptime Writer.satisfiedBy(@TypeOf(w)); -/// try w.writeAll(data); -/// } -/// ``` -/// -/// Common incompatibilities reported: -/// - Missing required methods -/// - Wrong parameter counts or types -/// - Incorrect return types -/// - Method name conflicts in embedded interfaces -/// - Non-const slices where const is required -/// -pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { - const embedded_interfaces = switch (@typeInfo(@TypeOf(embedded))) { - .null => embedded, - .@"struct" => |s| if (s.is_tuple) embedded else .{embedded}, - else => .{embedded}, - }; +fn generateVTableType(comptime methods: anytype, comptime embedded_interfaces: anytype, comptime has_embeds: bool) type { + comptime { + // Build array of struct fields for the VTable + var fields: []const std.builtin.Type.StructField = &.{}; + + // Helper function to add a method to the VTable + const addMethod = struct { + fn add(method_field: std.builtin.Type.StructField, method_fn: anytype, field_list: []const std.builtin.Type.StructField) []const std.builtin.Type.StructField { + const fn_info = @typeInfo(method_fn).@"fn"; + + // Build parameter list: insert *anyopaque as first param (implicit self) + var params: [fn_info.params.len + 1]std.builtin.Type.Fn.Param = undefined; + params[0] = .{ + .is_generic = false, + .is_noalias = false, + .type = *anyopaque, + }; + + // Copy all interface parameters after the implicit self + for (fn_info.params, 1..) |param, i| { + params[i] = param; + } - // Handle the case where null is passed for embedded_interfaces - const has_embeds = @TypeOf(embedded_interfaces) != @TypeOf(null); + // Create function pointer type + const FnType = @Type(.{ + .@"fn" = .{ + .calling_convention = fn_info.calling_convention, + .is_generic = false, + .is_var_args = false, + .return_type = fn_info.return_type, + .params = ¶ms, + }, + }); + + const FnPtrType = *const FnType; + + // Add field to VTable + return field_list ++ &[_]std.builtin.Type.StructField{.{ + .name = method_field.name, + .type = FnPtrType, + .default_value_ptr = null, + .is_comptime = false, + .alignment = @alignOf(FnPtrType), + }}; + } + }.add; + + // Helper to check if a field name already exists + const hasField = struct { + fn check(field_name: []const u8, field_list: []const std.builtin.Type.StructField) bool { + for (field_list) |field| { + if (std.mem.eql(u8, field.name, field_name)) { + return true; + } + } + return false; + } + }.check; + + // Add methods from embedded interfaces first + if (has_embeds) { + const Embeds = @TypeOf(embedded_interfaces); + for (std.meta.fields(Embeds)) |embed_field| { + const embed = @field(embedded_interfaces, embed_field.name); + // Recursively get the VTable type from the embedded interface + const EmbedVTable = embed.VTable; + for (std.meta.fields(EmbedVTable)) |vtable_field| { + // Skip if we already have this field (indicates a conflict that validation should catch) + if (!hasField(vtable_field.name, fields)) { + fields = fields ++ &[_]std.builtin.Type.StructField{vtable_field}; + } + } + } + } - return struct { - const Self = @This(); - const name = @typeName(Self); + // Add methods from primary interface + for (std.meta.fields(@TypeOf(methods))) |method_field| { + const method_fn = @field(methods, method_field.name); + // Only add if not already present from embedded interfaces + if (!hasField(method_field.name, fields)) { + fields = addMethod(method_field, method_fn, fields); + } + } - // Store these at the type level so they're accessible to helper functions + // Create the VTable struct type + return @Type(.{ + .@"struct" = .{ + .layout = .auto, + .fields = fields, + .decls = &.{}, + .is_tuple = false, + }, + }); + } +} + +fn CreateValidationNamespace(comptime methods: anytype, comptime embedded_interfaces: anytype, comptime has_embeds: bool) type { + return struct { const Methods = @TypeOf(methods); const Embeds = @TypeOf(embedded_interfaces); /// Represents all possible interface implementation problems - const Incompatibility = union(enum) { + pub const Incompatibility = union(enum) { missing_method: []const u8, wrong_param_count: struct { method: []const u8, @@ -225,7 +438,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { if (has_embeds) { for (std.meta.fields(Embeds)) |embed_field| { const embed = @field(embedded_interfaces, embed_field.name); - method_count += embed.collectMethodNames().len; + method_count += embed.validation.collectMethodNames().len; } } @@ -243,7 +456,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { if (has_embeds) { for (std.meta.fields(Embeds)) |embed_field| { const embed = @field(embedded_interfaces, embed_field.name); - const embed_methods = embed.collectMethodNames(); + const embed_methods = embed.validation.collectMethodNames(); @memcpy(names[index..][0..embed_methods.len], embed_methods); index += embed_methods.len; } @@ -267,7 +480,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { if (has_embeds) { for (std.meta.fields(Embeds)) |embed_field| { const embed = @field(embedded_interfaces, embed_field.name); - if (embed.hasMethod(method_name)) { + if (embed.validation.hasMethod(method_name)) { interface_count += 1; } } @@ -280,7 +493,6 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { // Add primary interface if (@hasDecl(Methods, method_name)) { - interfaces[index] = name; index += 1; } @@ -288,7 +500,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { if (has_embeds) { for (std.meta.fields(Embeds)) |embed_field| { const embed = @field(embedded_interfaces, embed_field.name); - if (embed.hasMethod(method_name)) { + if (embed.validation.hasMethod(method_name)) { interfaces[index] = @typeName(@TypeOf(embed)); index += 1; } @@ -300,7 +512,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { } /// Checks if this interface has a specific method - fn hasMethod(comptime method_name: []const u8) bool { + pub fn hasMethod(comptime method_name: []const u8) bool { comptime { // Check primary interface if (@hasDecl(Methods, method_name)) { @@ -311,7 +523,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { if (has_embeds) { for (std.meta.fields(Embeds)) |embed_field| { const embed = @field(embedded_interfaces, embed_field.name); - if (embed.hasMethod(method_name)) { + if (embed.validation.hasMethod(method_name)) { return true; } } @@ -329,19 +541,17 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { return Expected == Actual; } - if (exp_info.error_union.error_set == anyerror) { - return exp_info.error_union.payload == act_info.error_union.payload; - } - return Expected == Actual; + // Any error union in the interface accepts any error set in the implementation + return exp_info.error_union.payload == act_info.error_union.payload; } - pub fn incompatibilities(comptime Type: type) []const Incompatibility { + pub fn incompatibilities(comptime ImplType: type) []const Incompatibility { comptime { var problems: []const Incompatibility = &.{}; // First check for method ambiguity across all interfaces - for (Self.collectMethodNames()) |method_name| { - if (Self.findMethodConflicts(method_name)) |conflicting_interfaces| { + for (collectMethodNames()) |method_name| { + if (findMethodConflicts(method_name)) |conflicting_interfaces| { problems = problems ++ &[_]Incompatibility{.{ .ambiguous_method = .{ .method = method_name, @@ -356,29 +566,33 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { // Check primary interface methods for (std.meta.fields(@TypeOf(methods))) |field| { - if (!@hasDecl(Type, field.name)) { + if (!@hasDecl(ImplType, field.name)) { problems = problems ++ &[_]Incompatibility{.{ .missing_method = field.name, }}; continue; } - const impl_fn = @TypeOf(@field(Type, field.name)); + const impl_fn = @TypeOf(@field(ImplType, field.name)); const expected_fn = @field(methods, field.name); const impl_info = @typeInfo(impl_fn).@"fn"; const expected_info = @typeInfo(expected_fn).@"fn"; - if (impl_info.params.len != expected_info.params.len) { + // Implementation has self parameter, interface signature doesn't + const expected_param_count = expected_info.params.len + 1; + + if (impl_info.params.len != expected_param_count) { problems = problems ++ &[_]Incompatibility{.{ .wrong_param_count = .{ .method = field.name, - .expected = expected_info.params.len, + .expected = expected_param_count, .got = impl_info.params.len, }, }}; } else { - for (impl_info.params[1..], expected_info.params[1..], 0..) |impl_param, expected_param, i| { + // Compare impl params[1..] (skip self) with interface params[0..] + for (impl_info.params[1..], expected_info.params, 0..) |impl_param, expected_param, i| { if (!isTypeCompatible(impl_param.type.?, expected_param.type.?)) { problems = problems ++ &[_]Incompatibility{.{ .param_type_mismatch = .{ @@ -407,7 +621,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { if (has_embeds) { for (std.meta.fields(@TypeOf(embedded_interfaces))) |embed_field| { const embed = @field(embedded_interfaces, embed_field.name); - const embed_problems = embed.incompatibilities(Type); + const embed_problems = embed.validation.incompatibilities(ImplType); problems = problems ++ embed_problems; } } @@ -417,7 +631,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { } fn formatIncompatibility(incompatibility: Incompatibility) []const u8 { - const indent = " └─ "; + const indent = if (builtin.os.tag == .windows) " \\- " else " └─ "; return switch (incompatibility) { .missing_method => |method| std.fmt.comptimePrint("Missing required method: {s}\n{s}Add the method with the correct signature to your implementation", .{ method, indent }), @@ -453,17 +667,14 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { }; } - pub fn satisfiedBy(comptime Type: type) void { + pub fn satisfiedBy(comptime ImplType: type) void { comptime { - const problems = incompatibilities(Type); + const problems = incompatibilities(ImplType); if (problems.len > 0) { - const title = "Type '{s}' does not implement interface '{s}':\n"; + const title = "Type '{s}' does not implement the expected interface(s). To fix:\n"; // First compute the total size needed for our error message - var total_len: usize = std.fmt.count(title, .{ - @typeName(Type), - name, - }); + var total_len: usize = std.fmt.count(title, .{@typeName(ImplType)}); // Add space for each problem's length for (1.., problems) |i, problem| { @@ -474,10 +685,7 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { var errors: [total_len]u8 = undefined; var written: usize = 0; - written += (std.fmt.bufPrint(errors[written..], title, .{ - @typeName(Type), - name, - }) catch unreachable).len; + written += (std.fmt.bufPrint(errors[written..], title, .{@typeName(ImplType)}) catch unreachable).len; // Write each problem for (1.., problems) |i, problem| { @@ -490,54 +698,3 @@ pub fn Interface(comptime methods: anytype, comptime embedded: anytype) type { } }; } - -test "expected usage of embedded interfaces" { - const Logger = Interface(.{ - .log = fn (anytype, []const u8) void, - }, .{}); - - const Writer = Interface(.{ - .write = fn (anytype, []const u8) anyerror!void, - }, .{Logger}); - - const Implementation = struct { - pub fn write(self: @This(), data: []const u8) !void { - _ = self; - _ = data; - } - - pub fn log(self: @This(), msg: []const u8) void { - _ = self; - _ = msg; - } - }; - - comptime Writer.satisfiedBy(Implementation); - - try std.testing.expect(Writer.incompatibilities(Implementation).len == 0); -} - -test "expected failure case of embedded interfaces" { - const Logger = Interface(.{ - .log = fn (anytype, []const u8, u8) void, - .missing = fn (anytype) void, - }, .{}); - - const Writer = Interface(.{ - .write = fn (anytype, []const u8) anyerror!void, - }, .{Logger}); - - const Implementation = struct { - pub fn write(self: @This(), data: []const u8) !void { - _ = self; - _ = data; - } - - pub fn log(self: @This(), msg: []const u8) void { - _ = self; - _ = msg; - } - }; - - try std.testing.expect(Writer.incompatibilities(Implementation).len == 2); -} diff --git a/test/collections.zig b/test/collections.zig new file mode 100644 index 0000000..9c828dd --- /dev/null +++ b/test/collections.zig @@ -0,0 +1,267 @@ +const std = @import("std"); +const Interface = @import("interface").Interface; + +// Define State interface for a state machine +// Generate VTable-based runtime type +const State = Interface(.{ + .onEnter = fn () void, + .onExit = fn () void, + .update = fn (f32) void, +}, null); + +// Menu state implementation +const MenuState = struct { + name: []const u8, + entered: bool = false, + exited: bool = false, + updates: u32 = 0, + + pub fn onEnter(self: *MenuState) void { + self.entered = true; + } + + pub fn onExit(self: *MenuState) void { + self.exited = true; + } + + pub fn update(self: *MenuState, delta: f32) void { + _ = delta; + self.updates += 1; + } +}; + +// Gameplay state implementation +const GameplayState = struct { + score: u32 = 0, + time_elapsed: f32 = 0.0, + + pub fn onEnter(self: *GameplayState) void { + self.score = 0; + self.time_elapsed = 0.0; + } + + pub fn onExit(self: *GameplayState) void { + _ = self; + } + + pub fn update(self: *GameplayState, delta: f32) void { + self.time_elapsed += delta; + self.score += 10; + } +}; + +// Pause state implementation +const PauseState = struct { + paused_at: f32 = 0.0, + + pub fn onEnter(self: *PauseState) void { + _ = self; + } + + pub fn onExit(self: *PauseState) void { + _ = self; + } + + pub fn update(self: *PauseState, delta: f32) void { + self.paused_at += delta; + } +}; + +// State manager with stack of interface objects +const StateManager = struct { + stack: std.ArrayList(State), + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) StateManager { + return .{ + .stack = std.ArrayList(State){}, + .allocator = allocator, + }; + } + + pub fn deinit(self: *StateManager) void { + // Exit all states before cleanup + while (self.stack.items.len > 0) { + self.popState(); + } + self.stack.deinit(self.allocator); + } + + pub fn pushState(self: *StateManager, state: State) !void { + try self.stack.append(self.allocator, state); + // Call onEnter on the new state + const current = &self.stack.items[self.stack.items.len - 1]; + current.vtable.onEnter(current.ptr); + } + + pub fn popState(self: *StateManager) void { + if (self.stack.items.len > 0) { + const current = &self.stack.items[self.stack.items.len - 1]; + current.vtable.onExit(current.ptr); + _ = self.stack.pop(); + } + } + + pub fn update(self: *StateManager, delta: f32) void { + if (self.stack.items.len > 0) { + const current = &self.stack.items[self.stack.items.len - 1]; + current.vtable.update(current.ptr, delta); + } + } + + pub fn currentStateCount(self: StateManager) usize { + return self.stack.items.len; + } +}; + +test "state machine basic push and pop" { + var menu = MenuState{ .name = "Main Menu" }; + var gameplay = GameplayState{}; + + var manager = StateManager.init(std.testing.allocator); + defer manager.deinit(); + + // Initially empty + try std.testing.expectEqual(@as(usize, 0), manager.currentStateCount()); + + // Push menu state - auto-generated wrappers! + try manager.pushState(State.from(&menu)); + try std.testing.expectEqual(@as(usize, 1), manager.currentStateCount()); + try std.testing.expect(menu.entered); + try std.testing.expect(!menu.exited); + + // Push gameplay state - auto-generated wrappers! + try manager.pushState(State.from(&gameplay)); + try std.testing.expectEqual(@as(usize, 2), manager.currentStateCount()); + + // Pop gameplay + manager.popState(); + try std.testing.expectEqual(@as(usize, 1), manager.currentStateCount()); + + // Pop menu + manager.popState(); + try std.testing.expectEqual(@as(usize, 0), manager.currentStateCount()); + try std.testing.expect(menu.exited); +} + +test "state machine update propagation" { + var menu = MenuState{ .name = "Main Menu" }; + var gameplay = GameplayState{}; + + var manager = StateManager.init(std.testing.allocator); + defer manager.deinit(); + + // Push menu and update it + try manager.pushState(State.from(&menu)); + try std.testing.expectEqual(@as(u32, 0), menu.updates); + + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 1), menu.updates); + + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 2), menu.updates); + + // Push gameplay - it becomes the active state + try manager.pushState(State.from(&gameplay)); + try std.testing.expectEqual(@as(u32, 0), gameplay.score); + + manager.update(0.016); + // Gameplay updated, menu not updated + try std.testing.expectEqual(@as(u32, 10), gameplay.score); + try std.testing.expectEqual(@as(u32, 2), menu.updates); + + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 20), gameplay.score); + try std.testing.expectEqual(@as(u32, 2), menu.updates); +} + +test "state machine complex transitions" { + var menu = MenuState{ .name = "Main Menu" }; + var gameplay = GameplayState{}; + var pause = PauseState{}; + + var manager = StateManager.init(std.testing.allocator); + defer manager.deinit(); + + // Menu -> Gameplay -> Pause -> Gameplay -> Menu + try manager.pushState(State.from(&menu)); + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 1), menu.updates); + + try manager.pushState(State.from(&gameplay)); + manager.update(0.016); + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 20), gameplay.score); + + try manager.pushState(State.from(&pause)); + manager.update(0.016); + try std.testing.expectApproxEqAbs(@as(f32, 0.016), pause.paused_at, 0.001); + // Gameplay shouldn't update while paused + try std.testing.expectEqual(@as(u32, 20), gameplay.score); + + // Unpause + manager.popState(); + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 30), gameplay.score); + + // Back to menu + manager.popState(); + manager.update(0.016); + try std.testing.expectEqual(@as(u32, 2), menu.updates); +} + +test "heterogeneous collection of states" { + var menu1 = MenuState{ .name = "Main Menu" }; + var menu2 = MenuState{ .name = "Options Menu" }; + var gameplay1 = GameplayState{}; + var gameplay2 = GameplayState{}; + var pause = PauseState{}; + + // Create an array of different state types + const states = [_]State{ + State.from(&menu1), + State.from(&gameplay1), + State.from(&pause), + State.from(&menu2), + State.from(&gameplay2), + }; + + // All states can be called through the same interface + for (states) |state| { + state.vtable.onEnter(state.ptr); + state.vtable.update(state.ptr, 0.016); + state.vtable.onExit(state.ptr); + } + + // Verify they were all called + try std.testing.expect(menu1.entered); + try std.testing.expect(menu1.exited); + try std.testing.expect(menu2.entered); + try std.testing.expect(menu2.exited); + try std.testing.expectEqual(@as(u32, 1), menu1.updates); + try std.testing.expectEqual(@as(u32, 1), menu2.updates); + try std.testing.expectEqual(@as(u32, 10), gameplay1.score); + try std.testing.expectEqual(@as(u32, 10), gameplay2.score); +} + +test "state manager with multiple instance types" { + var menu = MenuState{ .name = "Main" }; + var gameplay = GameplayState{}; + var pause = PauseState{}; + + var manager = StateManager.init(std.testing.allocator); + defer manager.deinit(); + + // Push different types in sequence + try manager.pushState(State.from(&menu)); + try manager.pushState(State.from(&gameplay)); + try manager.pushState(State.from(&pause)); + + try std.testing.expectEqual(@as(usize, 3), manager.currentStateCount()); + + // Update only affects top of stack + manager.update(1.0); + try std.testing.expectApproxEqAbs(@as(f32, 1.0), pause.paused_at, 0.001); + try std.testing.expectEqual(@as(u32, 0), gameplay.score); + try std.testing.expectEqual(@as(u32, 0), menu.updates); +} diff --git a/test/complex.zig b/test/complex.zig index 14e2011..0394e2e 100644 --- a/test/complex.zig +++ b/test/complex.zig @@ -2,8 +2,8 @@ const std = @import("std"); const Interface = @import("interface").Interface; test "complex type support" { - const ComplexTypes = Interface(.{ - .complexMethod = fn (anytype, struct { a: []const u8, b: ?i32 }, enum { a, b, c }, []const struct { x: u32, y: ?[]const u8 }) anyerror!void, + const IComplexTypes = Interface(.{ + .complexMethod = fn (struct { a: []const u8, b: ?i32 }, enum { a, b, c }, []const struct { x: u32, y: ?[]const u8 }) anyerror!void, }, null); // Correct implementation @@ -22,7 +22,7 @@ test "complex type support" { }; // Should compile without error - comptime ComplexTypes.satisfiedBy(GoodImpl); + comptime IComplexTypes.validation.satisfiedBy(GoodImpl); // Bad implementation - mismatched struct field type const BadImpl1 = struct { @@ -69,9 +69,9 @@ test "complex type support" { } }; - try std.testing.expect(ComplexTypes.incompatibilities(BadImpl1).len > 0); - try std.testing.expect(ComplexTypes.incompatibilities(BadImpl2).len > 0); - try std.testing.expect(ComplexTypes.incompatibilities(BadImpl3).len > 0); + try std.testing.expect(IComplexTypes.validation.incompatibilities(BadImpl1).len > 0); + try std.testing.expect(IComplexTypes.validation.incompatibilities(BadImpl2).len > 0); + try std.testing.expect(IComplexTypes.validation.incompatibilities(BadImpl3).len > 0); } test "complex type support with embedding" { @@ -107,26 +107,26 @@ test "complex type support with embedding" { }; // Base interfaces with complex types - const Configurable = Interface(.{ - .configure = fn (anytype, Config) anyerror!void, - .getConfig = fn (anytype) Config, + const IConfigurable = Interface(.{ + .configure = fn (Config) anyerror!void, + .getConfig = fn () Config, }, null); - const StatusProvider = Interface(.{ - .getStatus = fn (anytype) Status, - .setStatus = fn (anytype, Status) anyerror!void, + const IStatusProvider = Interface(.{ + .getStatus = fn () Status, + .setStatus = fn (Status) anyerror!void, }, null); - const DataHandler = Interface(.{ - .processData = fn (anytype, []const DataPoint) anyerror!void, - .getLastPoint = fn (anytype) ?DataPoint, + const IDataHandler = Interface(.{ + .processData = fn ([]const DataPoint) anyerror!void, + .getLastPoint = fn () ?DataPoint, }, null); // Complex interface that embeds all the above and adds its own complex methods - const ComplexTypes = Interface(.{ - .complexMethod = fn (anytype, Config, Status, []const DataPoint) anyerror!void, - .superComplex = fn (anytype, ProcessingInput, ProcessingMode, []const HistoryEntry) anyerror!?ProcessingResult, - }, .{ Configurable, StatusProvider, DataHandler }); + const IComplexTypes = Interface(.{ + .complexMethod = fn (Config, Status, []const DataPoint) anyerror!void, + .superComplex = fn (ProcessingInput, ProcessingMode, []const HistoryEntry) anyerror!?ProcessingResult, + }, .{ IConfigurable, IStatusProvider, IDataHandler }); // Correct implementation const GoodImpl = struct { @@ -190,10 +190,10 @@ test "complex type support with embedding" { }; // Should compile without error - comptime ComplexTypes.satisfiedBy(GoodImpl); - comptime Configurable.satisfiedBy(GoodImpl); - comptime StatusProvider.satisfiedBy(GoodImpl); - comptime DataHandler.satisfiedBy(GoodImpl); + comptime IComplexTypes.validation.satisfiedBy(GoodImpl); + comptime IConfigurable.validation.satisfiedBy(GoodImpl); + comptime IStatusProvider.validation.satisfiedBy(GoodImpl); + comptime IDataHandler.validation.satisfiedBy(GoodImpl); // Bad implementation - missing embedded interface methods const BadImpl1 = struct { @@ -282,6 +282,6 @@ test "complex type support with embedding" { }; // Test that bad implementations are caught - try std.testing.expect(ComplexTypes.incompatibilities(BadImpl1).len > 0); - try std.testing.expect(ComplexTypes.incompatibilities(BadImpl2).len > 0); + try std.testing.expect(IComplexTypes.validation.incompatibilities(BadImpl1).len > 0); + try std.testing.expect(IComplexTypes.validation.incompatibilities(BadImpl2).len > 0); } diff --git a/test/embedded.zig b/test/embedded.zig index 3de3427..770e54a 100644 --- a/test/embedded.zig +++ b/test/embedded.zig @@ -10,21 +10,21 @@ const User = struct { test "interface embedding" { // Base interfaces const Logger = Interface(.{ - .log = fn (anytype, []const u8) void, - .getLogLevel = fn (anytype) u8, + .log = fn ([]const u8) void, + .getLogLevel = fn () u8, }, null); const Metrics = Interface(.{ - .increment = fn (anytype, []const u8) void, - .getValue = fn (anytype, []const u8) u64, + .increment = fn ([]const u8) void, + .getValue = fn ([]const u8) u64, }, .{Logger}); // Complex interface that embeds both Logger and Metrics const MonitoredRepository = Interface(.{ - .create = fn (anytype, User) anyerror!u32, - .findById = fn (anytype, u32) anyerror!?User, - .update = fn (anytype, User) anyerror!void, - .delete = fn (anytype, u32) anyerror!void, + .create = fn (User) anyerror!u32, + .findById = fn (u32) anyerror!?User, + .update = fn (User) anyerror!void, + .delete = fn (u32) anyerror!void, }, .{Metrics}); // Implementation that satisfies all interfaces @@ -114,9 +114,9 @@ test "interface embedding" { }; // Test that our implementation satisfies all interfaces - comptime MonitoredRepository.satisfiedBy(TrackedRepository); - comptime Logger.satisfiedBy(TrackedRepository); - comptime Metrics.satisfiedBy(TrackedRepository); + comptime MonitoredRepository.validation.satisfiedBy(TrackedRepository); + comptime Logger.validation.satisfiedBy(TrackedRepository); + comptime Metrics.validation.satisfiedBy(TrackedRepository); // Test the actual implementation var repo = try TrackedRepository.init(std.testing.allocator); @@ -138,18 +138,18 @@ test "interface embedding" { test "interface embedding with conflicts" { // Two interfaces with conflicting method names - const BasicLogger = Interface(.{ - .log = fn (anytype, []const u8) void, + const IBasicLogger = Interface(.{ + .log = fn ([]const u8) void, }, null); - const MetricLogger = Interface(.{ - .log = fn (anytype, []const u8, u64) void, + const IMetricLogger = Interface(.{ + .log = fn ([]const u8, u64) void, }, null); // This should fail to compile due to conflicting 'log' methods - const ConflictingLogger = Interface(.{ - .write = fn (anytype, []const u8) void, - }, .{ BasicLogger, MetricLogger }); + const IConflictingLogger = Interface(.{ + .write = fn ([]const u8) void, + }, .{ IBasicLogger, IMetricLogger }); // Implementation that tries to satisfy both const BadImplementation = struct { @@ -166,7 +166,7 @@ test "interface embedding with conflicts" { // This should fail compilation with an ambiguous method error comptime { - if (ConflictingLogger.incompatibilities(BadImplementation).len == 0) { + if (IConflictingLogger.validation.incompatibilities(BadImplementation).len == 0) { @compileError("Should have detected conflicting 'log' methods"); } } @@ -174,19 +174,19 @@ test "interface embedding with conflicts" { test "nested interface embedding" { // Base interface - const Closer = Interface(.{ - .close = fn (anytype) void, + const ICloser = Interface(.{ + .close = fn () void, }, null); // Mid-level interface that embeds Closer - const Writer = Interface(.{ - .write = fn (anytype, []const u8) anyerror!void, - }, .{Closer}); + const IWriter = Interface(.{ + .write = fn ([]const u8) anyerror!void, + }, .{ICloser}); // Top-level interface that embeds Writer - const FileWriter = Interface(.{ - .flush = fn (anytype) anyerror!void, - }, .{Writer}); + const IFileWriter = Interface(.{ + .flush = fn () anyerror!void, + }, .{IWriter}); // Implementation that satisfies all interfaces const Implementation = struct { @@ -205,7 +205,386 @@ test "nested interface embedding" { }; // Should satisfy all interfaces - comptime FileWriter.satisfiedBy(Implementation); - comptime Writer.satisfiedBy(Implementation); - comptime Closer.satisfiedBy(Implementation); + comptime IFileWriter.validation.satisfiedBy(Implementation); + comptime IWriter.validation.satisfiedBy(Implementation); + comptime ICloser.validation.satisfiedBy(Implementation); +} + +test "high-level: runtime polymorphism with embedded interfaces" { + // Define a practical monitoring system using embedded interfaces + const Logger = Interface(.{ + .log = fn ([]const u8) void, + .setLevel = fn (u8) void, + }, null); + + const Metrics = Interface(.{ + .recordCount = fn ([]const u8, u64) void, + .getCount = fn ([]const u8) u64, + }, .{Logger}); + + const Repository = Interface(.{ + .save = fn (User) anyerror!u32, + .load = fn (u32) anyerror!?User, + }, .{Metrics}); + + // Implementation 1: In-memory repository with full monitoring + const InMemoryRepo = struct { + allocator: std.mem.Allocator, + users: std.AutoHashMap(u32, User), + metrics: std.StringHashMap(u64), + next_id: u32, + log_level: u8, + + const Self = @This(); + + pub fn init(allocator: std.mem.Allocator) !Self { + return .{ + .allocator = allocator, + .users = std.AutoHashMap(u32, User).init(allocator), + .metrics = std.StringHashMap(u64).init(allocator), + .next_id = 1, + .log_level = 0, + }; + } + + pub fn deinit(self: *Self) void { + self.users.deinit(); + self.metrics.deinit(); + } + + pub fn log(self: Self, msg: []const u8) void { + _ = self; + _ = msg; + // In production: write to log + } + + pub fn setLevel(self: *Self, level: u8) void { + self.log_level = level; + } + + pub fn recordCount(self: *Self, key: []const u8, value: u64) void { + self.metrics.put(key, value) catch {}; + } + + pub fn getCount(self: Self, key: []const u8) u64 { + return self.metrics.get(key) orelse 0; + } + + pub fn save(self: *Self, user: User) !u32 { + self.log("Saving user to memory"); + self.recordCount("saves", self.getCount("saves") + 1); + + var new_user = user; + new_user.id = self.next_id; + try self.users.put(self.next_id, new_user); + self.next_id += 1; + return new_user.id; + } + + pub fn load(self: *Self, id: u32) !?User { + self.recordCount("loads", self.getCount("loads") + 1); + return self.users.get(id); + } + }; + + // Implementation 2: Cache repository (simpler, just tracks hits/misses) + const CacheRepo = struct { + cache: std.AutoHashMap(u32, User), + hits: u64, + misses: u64, + log_enabled: bool, + + const Self = @This(); + + pub fn init(allocator: std.mem.Allocator) Self { + return .{ + .cache = std.AutoHashMap(u32, User).init(allocator), + .hits = 0, + .misses = 0, + .log_enabled = true, + }; + } + + pub fn deinit(self: *Self) void { + self.cache.deinit(); + } + + pub fn log(self: Self, msg: []const u8) void { + if (self.log_enabled) { + _ = msg; + // In production: write to cache log + } + } + + pub fn setLevel(self: *Self, level: u8) void { + self.log_enabled = level > 0; + } + + pub fn recordCount(self: *Self, key: []const u8, value: u64) void { + if (std.mem.eql(u8, key, "hits")) { + self.hits = value; + } else if (std.mem.eql(u8, key, "misses")) { + self.misses = value; + } + } + + pub fn getCount(self: Self, key: []const u8) u64 { + if (std.mem.eql(u8, key, "hits")) return self.hits; + if (std.mem.eql(u8, key, "misses")) return self.misses; + return 0; + } + + pub fn save(self: *Self, user: User) !u32 { + self.log("Caching user"); + try self.cache.put(user.id, user); + return user.id; + } + + pub fn load(self: *Self, id: u32) !?User { + if (self.cache.get(id)) |user| { + self.recordCount("hits", self.hits + 1); + return user; + } else { + self.recordCount("misses", self.misses + 1); + return null; + } + } + }; + + // Implementation 3: No-op repository for testing + const NoOpRepo = struct { + call_count: u64, + + pub fn init() @This() { + return .{ .call_count = 0 }; + } + + pub fn log(_: @This(), _: []const u8) void {} + pub fn setLevel(_: *@This(), _: u8) void {} + pub fn recordCount(_: *@This(), _: []const u8, _: u64) void {} + pub fn getCount(_: @This(), _: []const u8) u64 { + return 0; + } + + pub fn save(self: *@This(), _: User) !u32 { + self.call_count += 1; + return 999; + } + + pub fn load(self: *@This(), _: u32) !?User { + self.call_count += 1; + return null; + } + }; + + // Verify all implementations satisfy the interface + comptime Repository.validation.satisfiedBy(InMemoryRepo); + comptime Repository.validation.satisfiedBy(CacheRepo); + comptime Repository.validation.satisfiedBy(NoOpRepo); + + // Create instances + var in_memory = try InMemoryRepo.init(std.testing.allocator); + defer in_memory.deinit(); + + var cache = CacheRepo.init(std.testing.allocator); + defer cache.deinit(); + + var noop = NoOpRepo.init(); + + // Convert to interface objects for runtime polymorphism + const repo1 = Repository.from(&in_memory); + const repo2 = Repository.from(&cache); + const repo3 = Repository.from(&noop); + + // Store in heterogeneous collection + const repositories = [_]Repository{ repo1, repo2, repo3 }; + + // Use all repositories polymorphically + const test_user = User{ .id = 0, .name = "Alice", .email = "alice@example.com" }; + + for (repositories) |repo| { + _ = try repo.vtable.save(repo.ptr, test_user); + repo.vtable.log(repo.ptr, "Operation complete"); + } + + // Verify each implementation behaved correctly + try std.testing.expectEqual(@as(u64, 1), in_memory.getCount("saves")); + try std.testing.expectEqual(@as(u32, 1), noop.call_count); + + // Test loading through interface + const loaded = try repo1.vtable.load(repo1.ptr, 1); + try std.testing.expect(loaded != null); + try std.testing.expectEqualStrings("Alice", loaded.?.name); +} + +test "high-level: repository fallback chain with embedded interfaces" { + // Demonstrate a practical pattern: fallback chain of repositories + const Logger = Interface(.{ + .log = fn ([]const u8) void, + }, null); + + const Repository = Interface(.{ + .get = fn ([]const u8) anyerror!?[]const u8, + .put = fn ([]const u8, []const u8) anyerror!void, + }, .{Logger}); + + // L1 Cache - fast, limited capacity + const L1Cache = struct { + data: std.StringHashMap([]const u8), + hits: usize, + + pub fn init(allocator: std.mem.Allocator) @This() { + return .{ + .data = std.StringHashMap([]const u8).init(allocator), + .hits = 0, + }; + } + + pub fn deinit(self: *@This()) void { + self.data.deinit(); + } + + pub fn log(_: @This(), msg: []const u8) void { + _ = msg; + } + + pub fn get(self: *@This(), key: []const u8) !?[]const u8 { + if (self.data.get(key)) |value| { + self.hits += 1; + return value; + } + return null; + } + + pub fn put(self: *@This(), key: []const u8, value: []const u8) !void { + try self.data.put(key, value); + } + }; + + // L2 Cache - slower, larger capacity + const L2Cache = struct { + data: std.StringHashMap([]const u8), + hits: usize, + + pub fn init(allocator: std.mem.Allocator) @This() { + return .{ + .data = std.StringHashMap([]const u8).init(allocator), + .hits = 0, + }; + } + + pub fn deinit(self: *@This()) void { + self.data.deinit(); + } + + pub fn log(_: @This(), msg: []const u8) void { + _ = msg; + } + + pub fn get(self: *@This(), key: []const u8) !?[]const u8 { + if (self.data.get(key)) |value| { + self.hits += 1; + return value; + } + return null; + } + + pub fn put(self: *@This(), key: []const u8, value: []const u8) !void { + try self.data.put(key, value); + } + }; + + // Backing store + const BackingStore = struct { + data: std.StringHashMap([]const u8), + reads: usize, + + pub fn init(allocator: std.mem.Allocator) @This() { + return .{ + .data = std.StringHashMap([]const u8).init(allocator), + .reads = 0, + }; + } + + pub fn deinit(self: *@This()) void { + self.data.deinit(); + } + + pub fn log(_: @This(), msg: []const u8) void { + _ = msg; + } + + pub fn get(self: *@This(), key: []const u8) !?[]const u8 { + self.reads += 1; + return self.data.get(key); + } + + pub fn put(self: *@This(), key: []const u8, value: []const u8) !void { + try self.data.put(key, value); + } + }; + + comptime Repository.validation.satisfiedBy(L1Cache); + comptime Repository.validation.satisfiedBy(L2Cache); + comptime Repository.validation.satisfiedBy(BackingStore); + + // Set up the fallback chain + var l1 = L1Cache.init(std.testing.allocator); + defer l1.deinit(); + + var l2 = L2Cache.init(std.testing.allocator); + defer l2.deinit(); + + var backing = BackingStore.init(std.testing.allocator); + defer backing.deinit(); + + // Pre-populate backing store + try backing.put("key1", "value1"); + try backing.put("key2", "value2"); + + // Create interface chain + const chain = [_]Repository{ + Repository.from(&l1), + Repository.from(&l2), + Repository.from(&backing), + }; + + // Function to get value through fallback chain + const getValue = struct { + fn get(repos: []const Repository, key: []const u8) !?[]const u8 { + for (repos) |repo| { + if (try repo.vtable.get(repo.ptr, key)) |value| { + return value; + } + } + return null; + } + }.get; + + // First access - should hit backing store + const val1 = try getValue(&chain, "key1"); + try std.testing.expect(val1 != null); + try std.testing.expectEqualStrings("value1", val1.?); + try std.testing.expectEqual(@as(usize, 0), l1.hits); + try std.testing.expectEqual(@as(usize, 0), l2.hits); + try std.testing.expectEqual(@as(usize, 1), backing.reads); + + // Populate L2 cache + try chain[1].vtable.put(chain[1].ptr, "key1", "value1"); + + // Second access - should hit L2 + const val2 = try getValue(&chain, "key1"); + try std.testing.expect(val2 != null); + try std.testing.expectEqual(@as(usize, 1), l2.hits); + + // Populate L1 cache + try chain[0].vtable.put(chain[0].ptr, "key1", "value1"); + + // Third access - should hit L1 + const val3 = try getValue(&chain, "key1"); + try std.testing.expect(val3 != null); + try std.testing.expectEqual(@as(usize, 1), l1.hits); + + // Still only 1 backing store read + try std.testing.expectEqual(@as(usize, 1), backing.reads); } diff --git a/test/inference.zig b/test/inference.zig new file mode 100644 index 0000000..282a31f --- /dev/null +++ b/test/inference.zig @@ -0,0 +1,252 @@ +const std = @import("std"); +const Interface = @import("interface").Interface; + +// Define our Generative AI API interface +const AIProvider = Interface(.{ + .generate = fn ([]const u8) anyerror![]const u8, + .embed = fn ([]const u8) anyerror![256]f16, + .query = fn ([]const u8) anyerror![][]const u8, +}, null); + +fn generate(provider: AIProvider, prompt: []const u8) ![]const u8 { + return provider.vtable.generate(provider.ptr, prompt); +} + +fn embed(provider: AIProvider, data: []const u8) ![256]f16 { + return provider.vtable.embed(provider.ptr, data); +} +fn query(provider: AIProvider, prompt: []const u8) ![][]const u8 { + return provider.vtable.query(provider.ptr, prompt); +} + +// OpenAI Mock Implementation +pub const OpenAIMock = struct { + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) OpenAIMock { + return .{ + .allocator = allocator, + }; + } + + pub fn deinit(self: *OpenAIMock) void { + _ = self; + } + + pub fn generate(self: *OpenAIMock, prompt: []const u8) ![]const u8 { + _ = self; + _ = prompt; + return "This is a mock response from OpenAI API"; + } + + pub fn embed(self: *OpenAIMock, input: []const u8) ![256]f16 { + _ = self; + _ = input; + var embeddings: [256]f16 = undefined; + for (&embeddings, 0..) |*val, i| { + val.* = @floatFromInt(@as(i16, @intCast(i))); + } + return embeddings; + } + + pub fn query(self: *OpenAIMock, input: []const u8) ![][]const u8 { + _ = input; + const results = try self.allocator.alloc([]const u8, 3); + results[0] = "OpenAI result 1"; + results[1] = "OpenAI result 2"; + results[2] = "OpenAI result 3"; + return results; + } +}; + +// Anthropic Mock Implementation +pub const AnthropicMock = struct { + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) AnthropicMock { + return .{ + .allocator = allocator, + }; + } + + pub fn deinit(self: *AnthropicMock) void { + _ = self; + } + + pub fn generate(self: *AnthropicMock, prompt: []const u8) ![]const u8 { + _ = self; + _ = prompt; + return "This is a mock response from Anthropic Claude API"; + } + + pub fn embed(self: *AnthropicMock, input: []const u8) ![256]f16 { + _ = self; + _ = input; + var embeddings: [256]f16 = undefined; + for (&embeddings, 0..) |*val, i| { + // Use a different pattern than OpenAI to distinguish them + val.* = @floatFromInt(@as(i16, @intCast(255 - i))); + } + return embeddings; + } + + pub fn query(self: *AnthropicMock, input: []const u8) ![][]const u8 { + _ = input; + const results = try self.allocator.alloc([]const u8, 2); + results[0] = "Anthropic result 1"; + results[1] = "Anthropic result 2"; + return results; + } +}; + +// Example function that works with any Generative AI implementation +fn processPrompt(api: anytype, prompt: []const u8) ![]const u8 { + comptime AIProvider.validation.satisfiedBy(@TypeOf(api.*)); + return try api.generate(prompt); +} + +test "OpenAI mock satisfies interface" { + var openai = OpenAIMock.init(std.testing.allocator); + defer openai.deinit(); + + // Verify at comptime that our implementation satisfies the interface + comptime AIProvider.validation.satisfiedBy(OpenAIMock); + + // Test generate + const response = try openai.generate("Test prompt"); + try std.testing.expectEqualStrings("This is a mock response from OpenAI API", response); + + // Test embed + const embeddings = try openai.embed("Test input"); + try std.testing.expectEqual(@as(f16, 0.0), embeddings[0]); + try std.testing.expectEqual(@as(f16, 255.0), embeddings[255]); + + // Test query + const results = try openai.query("Test query"); + defer std.testing.allocator.free(results); + try std.testing.expectEqual(@as(usize, 3), results.len); + try std.testing.expectEqualStrings("OpenAI result 1", results[0]); +} + +test "Anthropic mock satisfies interface" { + var anthropic = AnthropicMock.init(std.testing.allocator); + defer anthropic.deinit(); + + // Verify at comptime that our implementation satisfies the interface + comptime AIProvider.validation.satisfiedBy(AnthropicMock); + + // Test generate + const response = try anthropic.generate("Test prompt"); + try std.testing.expectEqualStrings("This is a mock response from Anthropic Claude API", response); + + // Test embed + const embeddings = try anthropic.embed("Test input"); + try std.testing.expectEqual(@as(f16, 255.0), embeddings[0]); + try std.testing.expectEqual(@as(f16, 0.0), embeddings[255]); + + // Test query + const results = try anthropic.query("Test query"); + defer std.testing.allocator.free(results); + try std.testing.expectEqual(@as(usize, 2), results.len); + try std.testing.expectEqualStrings("Anthropic result 1", results[0]); +} + +test "processPrompt works with both implementations" { + // Test with OpenAI + var openai = OpenAIMock.init(std.testing.allocator); + defer openai.deinit(); + + const openai_response = try processPrompt(&openai, "Hello"); + try std.testing.expectEqualStrings("This is a mock response from OpenAI API", openai_response); + + // Test with Anthropic + var anthropic = AnthropicMock.init(std.testing.allocator); + defer anthropic.deinit(); + + const anthropic_response = try processPrompt(&anthropic, "Hello"); + try std.testing.expectEqualStrings("This is a mock response from Anthropic Claude API", anthropic_response); +} + +const Wrong = struct {}; + +test "Inference wrapper with VTable-based providers" { + // Create OpenAI inference instance using VTable + var openai_provider = OpenAIMock.init(std.testing.allocator); + defer openai_provider.deinit(); + const openai_interface = AIProvider.from(&openai_provider); + + // Test OpenAI generate + const openai_response = try generate(openai_interface, "Test prompt"); + try std.testing.expectEqualStrings("This is a mock response from OpenAI API", openai_response); + + // Test OpenAI embed + const openai_embeddings = try embed(openai_interface, "Test input"); + try std.testing.expectEqual(@as(f16, 0.0), openai_embeddings[0]); + try std.testing.expectEqual(@as(f16, 255.0), openai_embeddings[255]); + + // Test OpenAI query + const openai_results = try query(openai_interface, "Test query"); + defer std.testing.allocator.free(openai_results); + try std.testing.expectEqual(@as(usize, 3), openai_results.len); + try std.testing.expectEqualStrings("OpenAI result 1", openai_results[0]); + try std.testing.expectEqualStrings("OpenAI result 2", openai_results[1]); + try std.testing.expectEqualStrings("OpenAI result 3", openai_results[2]); + + // Create Anthropic inference instance using VTable + var anthropic_provider = AnthropicMock.init(std.testing.allocator); + defer anthropic_provider.deinit(); + const anthropic_interface = AIProvider.from(&anthropic_provider); + + // Test Anthropic generate + const anthropic_response = try generate(anthropic_interface, "Test prompt"); + try std.testing.expectEqualStrings("This is a mock response from Anthropic Claude API", anthropic_response); + + // Test Anthropic embed + const anthropic_embeddings = try embed(anthropic_interface, "Test input"); + try std.testing.expectEqual(@as(f16, 255.0), anthropic_embeddings[0]); + try std.testing.expectEqual(@as(f16, 0.0), anthropic_embeddings[255]); + + // Test Anthropic query + const anthropic_results = try query(anthropic_interface, "Test query"); + defer std.testing.allocator.free(anthropic_results); + try std.testing.expectEqual(@as(usize, 2), anthropic_results.len); + try std.testing.expectEqualStrings("Anthropic result 1", anthropic_results[0]); + try std.testing.expectEqualStrings("Anthropic result 2", anthropic_results[1]); +} + +test "Runtime polymorphism with heterogeneous providers" { + // Create both providers + var openai_provider = OpenAIMock.init(std.testing.allocator); + defer openai_provider.deinit(); + var anthropic_provider = AnthropicMock.init(std.testing.allocator); + defer anthropic_provider.deinit(); + + // Store different provider types in an array (runtime polymorphism!) + const providers = [_]AIProvider{ + AIProvider.from(&openai_provider), + AIProvider.from(&anthropic_provider), + }; + + // Test that we can call through the array and get different results + const openai_response = try generate(providers[0], "prompt"); + const anthropic_response = try generate(providers[1], "prompt"); + + try std.testing.expectEqualStrings("This is a mock response from OpenAI API", openai_response); + try std.testing.expectEqualStrings("This is a mock response from Anthropic Claude API", anthropic_response); + + // Test embeddings are different + const openai_embed = try embed(providers[0], "input"); + const anthropic_embed = try embed(providers[1], "input"); + + try std.testing.expectEqual(@as(f16, 0.0), openai_embed[0]); + try std.testing.expectEqual(@as(f16, 255.0), anthropic_embed[0]); + + // Test query returns different number of results + const openai_query = try query(providers[0], "query"); + defer std.testing.allocator.free(openai_query); + const anthropic_query = try query(providers[1], "query"); + defer std.testing.allocator.free(anthropic_query); + + try std.testing.expectEqual(@as(usize, 3), openai_query.len); + try std.testing.expectEqual(@as(usize, 2), anthropic_query.len); +} diff --git a/test/simple.zig b/test/simple.zig index b0a9465..dcf4bc7 100644 --- a/test/simple.zig +++ b/test/simple.zig @@ -9,13 +9,13 @@ const User = struct { }; // Define our Repository interface with multiple methods -// Note the anytype to indicate pointer methods +// Interface() now returns the vtable-based type directly const Repository = Interface(.{ - .create = fn (anytype, User) anyerror!u32, - .findById = fn (anytype, u32) anyerror!?User, - .update = fn (anytype, User) anyerror!void, - .delete = fn (anytype, u32) anyerror!void, - .findByEmail = fn (anytype, []const u8) anyerror!?User, + .create = fn (User) anyerror!u32, + .findById = fn (u32) anyerror!?User, + .update = fn (User) anyerror!void, + .delete = fn (u32) anyerror!void, + .findByEmail = fn ([]const u8) anyerror!?User, }, null); // Implement a simple in-memory repository @@ -73,9 +73,10 @@ pub const InMemoryRepository = struct { } }; -// Function that works with any Repository implementation +// Function that works with any Repository implementation (compile-time duck typing) fn createUser(repo: anytype, name: []const u8, email: []const u8) !User { - comptime Repository.satisfiedBy(@TypeOf(repo.*)); // Required to be called by function author + // Use .validation.satisfiedBy() to verify interface compliance at compile time + comptime Repository.validation.satisfiedBy(@TypeOf(repo.*)); const user = User{ .id = 0, @@ -91,14 +92,31 @@ fn createUser(repo: anytype, name: []const u8, email: []const u8) !User { }; } +// Function that works with any Repository implementation via vtable (runtime polymorphism) +fn dynCreateUser(repo: Repository, name: []const u8, email: []const u8) !User { + const user = User{ + .id = 0, + .name = name, + .email = email, + }; + + const id = try repo.vtable.create(repo.ptr, user); + return User{ + .id = id, + .name = name, + .email = email, + }; +} + test "repository interface" { var repo = InMemoryRepository.init(std.testing.allocator); defer repo.deinit(); // Verify at comptime that our implementation satisfies the interface - comptime Repository.satisfiedBy(@TypeOf(repo)); // Required to be called by function author + // Use .validation namespace for compile-time validation + comptime Repository.validation.satisfiedBy(@TypeOf(repo)); // or, can pass the concrete struct type directly: - comptime Repository.satisfiedBy(InMemoryRepository); + comptime Repository.validation.satisfiedBy(InMemoryRepository); // Test create and findById const user1 = try createUser(&repo, "John Doe", "john@example.com"); @@ -132,3 +150,40 @@ test "repository interface" { })); try std.testing.expectError(error.UserNotFound, repo.delete(999)); } + +test "dynamic repository interface" { + var repo = InMemoryRepository.init(std.testing.allocator); + defer repo.deinit(); + + // Test create and findById + const user1 = try dynCreateUser(Repository.from(&repo), "John Doe", "john@example.com"); + const found = try repo.findById(user1.id); + try std.testing.expect(found != null); + try std.testing.expectEqualStrings("John Doe", found.?.name); + + // Test findByEmail + const by_email = try repo.findByEmail("john@example.com"); + try std.testing.expect(by_email != null); + try std.testing.expectEqual(user1.id, by_email.?.id); + + // Test update + var updated_user = user1; + updated_user.name = "Johnny Doe"; + try repo.update(updated_user); + const found_updated = try repo.findById(user1.id); + try std.testing.expect(found_updated != null); + try std.testing.expectEqualStrings("Johnny Doe", found_updated.?.name); + + // Test delete + try repo.delete(user1.id); + const not_found = try repo.findById(user1.id); + try std.testing.expect(not_found == null); + + // Test error cases + try std.testing.expectError(error.UserNotFound, repo.update(User{ + .id = 999, + .name = "Not Found", + .email = "none@example.com", + })); + try std.testing.expectError(error.UserNotFound, repo.delete(999)); +} diff --git a/test/vtable.zig b/test/vtable.zig new file mode 100644 index 0000000..5c8bb88 --- /dev/null +++ b/test/vtable.zig @@ -0,0 +1,83 @@ +const std = @import("std"); +const Interface = @import("interface").Interface; + +// Simple interface to test VTable generation +const IWriter = Interface(.{ + .write = fn ([]const u8) anyerror!usize, +}, null); + +// Generate the VTable-based runtime type +const Writer = IWriter; + +// Test implementation - Simplified with auto-generated wrappers +const BufferWriter = struct { + buffer: std.ArrayList(u8), + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) BufferWriter { + return .{ + .buffer = std.ArrayList(u8){}, + .allocator = allocator, + }; + } + + pub fn deinit(self: *BufferWriter) void { + self.buffer.deinit(self.allocator); + } + + pub fn write(self: *BufferWriter, data: []const u8) !usize { + try self.buffer.appendSlice(self.allocator, data); + return data.len; + } + + pub fn getWritten(self: *const BufferWriter) []const u8 { + return self.buffer.items; + } +}; + +test "vtable interface type generation" { + // Verify the interface type was created + const VTableType = Writer.VTable; + const vtable_fields = std.meta.fields(VTableType); + + try std.testing.expectEqual(@as(usize, 1), vtable_fields.len); + try std.testing.expectEqualStrings("write", vtable_fields[0].name); +} + +test "vtable interface runtime usage with from()" { + var buffer_writer = BufferWriter.init(std.testing.allocator); + defer buffer_writer.deinit(); + + // Create interface wrapper using from() - no manual wrappers needed! + const writer_interface = Writer.from(&buffer_writer); + + // Use through the interface + const written = try writer_interface.vtable.write(writer_interface.ptr, "Hello, "); + try std.testing.expectEqual(@as(usize, 7), written); + + const written2 = try writer_interface.vtable.write(writer_interface.ptr, "World!"); + try std.testing.expectEqual(@as(usize, 6), written2); + + // Verify the data was written + try std.testing.expectEqualStrings("Hello, World!", buffer_writer.getWritten()); +} + +test "vtable interface with multiple implementations" { + // First implementation + var buffer_writer1 = BufferWriter.init(std.testing.allocator); + defer buffer_writer1.deinit(); + + var buffer_writer2 = BufferWriter.init(std.testing.allocator); + defer buffer_writer2.deinit(); + + // Create interface wrappers - auto-generated VTables + const writer1 = Writer.from(&buffer_writer1); + const writer2 = Writer.from(&buffer_writer2); + + // Write to different writers through same interface + _ = try writer1.vtable.write(writer1.ptr, "First"); + _ = try writer2.vtable.write(writer2.ptr, "Second"); + + try std.testing.expectEqualStrings("First", buffer_writer1.getWritten()); + try std.testing.expectEqualStrings("Second", buffer_writer2.getWritten()); +}