Skip to content

Conversation

@etherealsunshine
Copy link
Contributor

@etherealsunshine etherealsunshine commented Aug 6, 2025

Description

Adds support for Apple Metal Perfomance shaders when loading models from checkpoint in UME. Addresses #180

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring

@etherealsunshine
Copy link
Contributor Author

etherealsunshine commented Aug 6, 2025

On the lookout for a python implementation of Flash Attention for MPS architecture, as of right now, I can't seem to find any.

@ncfrey
Copy link
Contributor

ncfrey commented Aug 13, 2025

@etherealsunshine have you tried this out with MPS?

@etherealsunshine
Copy link
Contributor Author

Hi @ncfrey, so MPS gets auto-detected correctly, Flash Attention properly falls back to SDPA on MPS, and basic tensor operations work. I just tested the device detection and configuration logic directly since I don't have access to the pretrained models - but the core MPS handling works correctly. Dont have immediate plans to for training a custom model yet, so if you think thats sufficient I can add these validation tests to the tests. Let me know what you think!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants