Skip to content

Commit 6d7fb90

Browse files
add modelprovider interface and providers
Signed-off-by: Avinash Singh <[email protected]>
1 parent 643b5a1 commit 6d7fb90

File tree

10 files changed

+951
-71
lines changed

10 files changed

+951
-71
lines changed

cmd/modelfile/generate.go

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -25,20 +25,20 @@ import (
2525
"github.com/spf13/viper"
2626

2727
configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile"
28-
"github.com/modelpack/modctl/pkg/hfhub"
2928
"github.com/modelpack/modctl/pkg/modelfile"
29+
"github.com/modelpack/modctl/pkg/modelprovider"
3030
)
3131

3232
var generateConfig = configmodelfile.NewGenerateConfig()
3333

3434
// generateCmd represents the modelfile tools command for generating modelfile.
3535
var generateCmd = &cobra.Command{
3636
Use: "generate [flags] [<path>]",
37-
Short: "Generate a modelfile from a local workspace or Hugging Face model",
38-
Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from Hugging Face.
37+
Short: "Generate a modelfile from a local workspace or remote model provider",
38+
Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider.
3939
4040
The workspace must be a directory including model files and model configuration files.
41-
Alternatively, use --model_url to download a model from Hugging Face Hub.`,
41+
Alternatively, use --model_url to download a model from a supported provider (e.g., HuggingFace, ModelScope).`,
4242
Example: ` # Generate from local directory
4343
modctl modelfile generate ./my-model-dir
4444
@@ -48,6 +48,9 @@ Alternatively, use --model_url to download a model from Hugging Face Hub.`,
4848
# Generate from Hugging Face using short form
4949
modctl modelfile generate --model_url meta-llama/Llama-2-7b-hf
5050
51+
# Generate from ModelScope
52+
modctl modelfile generate --model_url https://modelscope.cn/models/qwen/Qwen-7B
53+
5154
# Generate with custom output path
5255
modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml
5356
@@ -97,7 +100,7 @@ func init() {
97100
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
98101
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
99102
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
100-
flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from Hugging Face (format: owner/repo or full URL)")
103+
flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from a supported provider (HuggingFace: owner/repo or full URL, ModelScope: full URL)")
101104

102105
// Mark the ignore-unrecognized-file-types flag as deprecated and hidden
103106
flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release")
@@ -114,23 +117,32 @@ func runGenerate(ctx context.Context) error {
114117
if generateConfig.ModelURL != "" {
115118
fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL)
116119

117-
// Check if user is authenticated with Hugging Face
118-
if err := hfhub.CheckHuggingFaceAuth(); err != nil {
119-
return fmt.Errorf("authentication check failed: %w", err)
120+
// Get the appropriate provider for this URL
121+
registry := modelprovider.NewRegistry()
122+
provider, err := registry.GetProvider(generateConfig.ModelURL)
123+
if err != nil {
124+
return fmt.Errorf("unsupported model URL: %w", err)
125+
}
126+
127+
fmt.Printf("Using provider: %s\n", provider.Name())
128+
129+
// Check if user is authenticated with the provider
130+
if err := provider.CheckAuth(); err != nil {
131+
return fmt.Errorf("%s authentication check failed: %w", provider.Name(), err)
120132
}
121133

122134
// Create a temporary directory for downloading the model
123135
// Clean up the temporary directory after the function returns
124-
tmpDir, err := os.MkdirTemp("", "modctl-hf-downloads-*")
136+
tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*")
125137
if err != nil {
126138
return fmt.Errorf("failed to create temporary directory: %w", err)
127139
}
128140
defer os.RemoveAll(tmpDir)
129141

130142
// Download the model
131-
downloadPath, err := hfhub.DownloadModel(ctx, generateConfig.ModelURL, tmpDir)
143+
downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, tmpDir)
132144
if err != nil {
133-
return fmt.Errorf("failed to download model: %w", err)
145+
return fmt.Errorf("failed to download model from %s: %w", provider.Name(), err)
134146
}
135147

136148
// Update workspace to the downloaded model path

pkg/hfhub/download.go renamed to pkg/modelprovider/huggingface/downloader.go

Lines changed: 15 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package hfhub
17+
package huggingface
1818

1919
import (
2020
"context"
@@ -29,11 +29,11 @@ import (
2929
)
3030

3131
const (
32-
HuggingFaceBaseURL = "https://huggingface.co"
32+
huggingFaceBaseURL = "https://huggingface.co"
3333
)
3434

35-
// ParseModelURL parses a Hugging Face model URL and extracts the owner and repository name
36-
func ParseModelURL(modelURL string) (owner, repo string, err error) {
35+
// parseModelURL parses a HuggingFace model URL and extracts the owner and repository name
36+
func parseModelURL(modelURL string) (owner, repo string, err error) {
3737
// Handle both full URLs and short-form repo names
3838
modelURL = strings.TrimSpace(modelURL)
3939

@@ -50,7 +50,7 @@ func ParseModelURL(modelURL string) (owner, repo string, err error) {
5050
// Expected format: https://huggingface.co/owner/repo
5151
parts := strings.Split(strings.Trim(u.Path, "/"), "/")
5252
if len(parts) < 2 {
53-
return "", "", fmt.Errorf("invalid Hugging Face URL format, expected https://huggingface.co/owner/repo")
53+
return "", "", fmt.Errorf("invalid HuggingFace URL format, expected https://huggingface.co/owner/repo")
5454
}
5555

5656
owner = parts[0]
@@ -73,45 +73,8 @@ func ParseModelURL(modelURL string) (owner, repo string, err error) {
7373
return owner, repo, nil
7474
}
7575

76-
// DownloadModel downloads a model from Hugging Face using the huggingface-cli
77-
// It assumes the user is already logged in via `huggingface-cli login`
78-
func DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) {
79-
owner, repo, err := ParseModelURL(modelURL)
80-
if err != nil {
81-
return "", err
82-
}
83-
84-
repoID := fmt.Sprintf("%s/%s", owner, repo)
85-
86-
// Check if huggingface-cli is available
87-
if _, err := exec.LookPath("huggingface-cli"); err != nil {
88-
return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]")
89-
}
90-
91-
// Create destination directory if it doesn't exist
92-
if err := os.MkdirAll(destDir, 0755); err != nil {
93-
return "", fmt.Errorf("failed to create destination directory: %w", err)
94-
}
95-
96-
// Construct the download path
97-
downloadPath := filepath.Join(destDir, repo)
98-
99-
// Use huggingface-cli to download the model
100-
// The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked
101-
cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False")
102-
103-
cmd.Stdout = os.Stdout
104-
cmd.Stderr = os.Stderr
105-
106-
if err := cmd.Run(); err != nil {
107-
return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err)
108-
}
109-
110-
return downloadPath, nil
111-
}
112-
113-
// CheckHuggingFaceAuth checks if the user is authenticated with Hugging Face
114-
func CheckHuggingFaceAuth() error {
76+
// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace
77+
func checkHuggingFaceAuth() error {
11578
// Try to find the HF token
11679
token := os.Getenv("HF_TOKEN")
11780
if token != "" {
@@ -139,11 +102,11 @@ func CheckHuggingFaceAuth() error {
139102
}
140103
}
141104

142-
return fmt.Errorf("not authenticated with Hugging Face. Please run: huggingface-cli login")
105+
return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login")
143106
}
144107

145-
// GetToken retrieves the Hugging Face token from environment or token file
146-
func GetToken() (string, error) {
108+
// getToken retrieves the HuggingFace token from environment or token file
109+
func getToken() (string, error) {
147110
// First check environment variable
148111
token := os.Getenv("HF_TOKEN")
149112
if token != "" {
@@ -165,16 +128,16 @@ func GetToken() (string, error) {
165128
return strings.TrimSpace(string(data)), nil
166129
}
167130

168-
// DownloadFile downloads a single file from Hugging Face
169-
func DownloadFile(ctx context.Context, owner, repo, filename, destPath string) error {
170-
token, err := GetToken()
131+
// downloadFile downloads a single file from HuggingFace
132+
func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error {
133+
token, err := getToken()
171134
if err != nil {
172-
return fmt.Errorf("failed to get Hugging Face token: %w", err)
135+
return fmt.Errorf("failed to get HuggingFace token: %w", err)
173136
}
174137

175138
// Construct the download URL
176139
// Format: https://huggingface.co/{owner}/{repo}/resolve/main/{filename}
177-
fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", HuggingFaceBaseURL, owner, repo, filename)
140+
fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", huggingFaceBaseURL, owner, repo, filename)
178141

179142
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
180143
if err != nil {

pkg/hfhub/download_test.go renamed to pkg/modelprovider/huggingface/downloader_test.go

Lines changed: 59 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package hfhub
17+
package huggingface
1818

1919
import (
2020
"strings"
@@ -62,7 +62,7 @@ func TestParseModelURL(t *testing.T) {
6262
name: "invalid format - missing repo",
6363
modelURL: "https://huggingface.co/meta-llama",
6464
wantErr: true,
65-
errContains: "invalid Hugging Face URL format",
65+
errContains: "invalid HuggingFace URL format",
6666
},
6767
{
6868
name: "invalid format - only owner",
@@ -87,31 +87,82 @@ func TestParseModelURL(t *testing.T) {
8787

8888
for _, tt := range tests {
8989
t.Run(tt.name, func(t *testing.T) {
90-
owner, repo, err := ParseModelURL(tt.modelURL)
90+
owner, repo, err := parseModelURL(tt.modelURL)
9191

9292
if tt.wantErr {
9393
if err == nil {
94-
t.Errorf("ParseModelURL() expected error but got nil")
94+
t.Errorf("parseModelURL() expected error but got nil")
9595
return
9696
}
9797
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
98-
t.Errorf("ParseModelURL() error = %v, want error containing %v", err, tt.errContains)
98+
t.Errorf("parseModelURL() error = %v, want error containing %v", err, tt.errContains)
9999
}
100100
return
101101
}
102102

103103
if err != nil {
104-
t.Errorf("ParseModelURL() unexpected error = %v", err)
104+
t.Errorf("parseModelURL() unexpected error = %v", err)
105105
return
106106
}
107107

108108
if owner != tt.wantOwner {
109-
t.Errorf("ParseModelURL() owner = %v, want %v", owner, tt.wantOwner)
109+
t.Errorf("parseModelURL() owner = %v, want %v", owner, tt.wantOwner)
110110
}
111111

112112
if repo != tt.wantRepo {
113-
t.Errorf("ParseModelURL() repo = %v, want %v", repo, tt.wantRepo)
113+
t.Errorf("parseModelURL() repo = %v, want %v", repo, tt.wantRepo)
114114
}
115115
})
116116
}
117117
}
118+
119+
func TestProvider_SupportsURL(t *testing.T) {
120+
provider := New()
121+
122+
tests := []struct {
123+
name string
124+
url string
125+
want bool
126+
}{
127+
{
128+
name: "full HuggingFace URL",
129+
url: "https://huggingface.co/meta-llama/Llama-2-7b-hf",
130+
want: true,
131+
},
132+
{
133+
name: "short form repo",
134+
url: "meta-llama/Llama-2-7b-hf",
135+
want: true,
136+
},
137+
{
138+
name: "ModelScope URL",
139+
url: "https://modelscope.cn/models/owner/repo",
140+
want: false,
141+
},
142+
{
143+
name: "invalid format",
144+
url: "just-a-string",
145+
want: false,
146+
},
147+
{
148+
name: "HTTP URL",
149+
url: "http://example.com/owner/repo",
150+
want: false,
151+
},
152+
}
153+
154+
for _, tt := range tests {
155+
t.Run(tt.name, func(t *testing.T) {
156+
if got := provider.SupportsURL(tt.url); got != tt.want {
157+
t.Errorf("Provider.SupportsURL() = %v, want %v", got, tt.want)
158+
}
159+
})
160+
}
161+
}
162+
163+
func TestProvider_Name(t *testing.T) {
164+
provider := New()
165+
if got := provider.Name(); got != "huggingface" {
166+
t.Errorf("Provider.Name() = %v, want %v", got, "huggingface")
167+
}
168+
}

0 commit comments

Comments (0)