Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 86 additions & 5 deletions cmd/modelfile/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,62 @@ import (

configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile"
"github.com/modelpack/modctl/pkg/modelfile"
"github.com/modelpack/modctl/pkg/modelprovider"
)

var generateConfig = configmodelfile.NewGenerateConfig()

// generateCmd represents the modelfile tools command for generating modelfile.
var generateCmd = &cobra.Command{
Use: "generate [flags] <path>",
Short: "A command line tool for generating modelfile in the workspace, the workspace must be a directory including model files and model configuration files",
Args: cobra.ExactArgs(1),
Use: "generate [flags] [<path>]",
Short: "Generate a modelfile from a local workspace or remote model provider",
Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider.

The workspace must be a directory including model files and model configuration files.
Alternatively, use --model-url to download a model from a supported provider (e.g., HuggingFace, ModelScope).

For short-form URLs (owner/repo), you must explicitly specify the provider using --provider flag.
Full URLs with domain names will auto-detect the provider.`,
Example: ` # Generate from local directory
modctl modelfile generate ./my-model-dir

# Generate from Hugging Face using full URL (auto-detects provider)
modctl modelfile generate --model-url https://huggingface.co/meta-llama/Llama-2-7b-hf

# Generate from Hugging Face using short form (requires --provider)
modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface

# Generate from ModelScope using full URL (auto-detects provider)
modctl modelfile generate --model-url https://modelscope.cn/models/qwen/Qwen-7B

# Generate from ModelScope using short form (requires --provider)
modctl modelfile generate --model-url qwen/Qwen-7B --provider modelscope

# Generate with custom output path
modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml

# Generate with metadata overrides
modctl modelfile generate ./my-model-dir --name my-custom-model --family llama3`,
Args: cobra.MaximumNArgs(1),
DisableAutoGenTag: true,
SilenceUsage: true,
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
RunE: func(cmd *cobra.Command, args []string) error {
if err := generateConfig.Convert(args[0]); err != nil {
// If model-url is provided, path is optional
workspace := "."
if len(args) > 0 {
workspace = args[0]
}

// Validate that either path or model-url is provided
if generateConfig.ModelURL != "" && len(args) > 0 {
return fmt.Errorf("the <path> argument and the --model-url flag are mutually exclusive")
}
if generateConfig.ModelURL == "" && len(args) == 0 {
return fmt.Errorf("either a <path> argument or the --model-url flag must be provided")
}

if err := generateConfig.Convert(workspace); err != nil {
return err
}

Expand All @@ -64,6 +106,8 @@ func init() {
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
flags.StringVar(&generateConfig.ModelURL, "model-url", "", "download model from a supported provider (full URL or short-form with --provider)")
flags.StringVarP(&generateConfig.Provider, "provider", "p", "", "explicitly specify the provider for short-form URLs (huggingface, modelscope)")
flags.StringArrayVar(&generateConfig.ExcludePatterns, "exclude", []string{}, "specify glob patterns to exclude files/directories (e.g. *.log, checkpoints/*)")

// Mark the ignore-unrecognized-file-types flag as deprecated and hidden
Expand All @@ -76,7 +120,44 @@ func init() {
}

// runGenerate runs the generate modelfile.
func runGenerate(_ context.Context) error {
func runGenerate(ctx context.Context) error {
// If model URL is provided, download the model first
if generateConfig.ModelURL != "" {
fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL)

// Get the appropriate provider for this URL
registry := modelprovider.GetRegistry()
provider, err := registry.SelectProvider(generateConfig.ModelURL, generateConfig.Provider)
if err != nil {
return fmt.Errorf("failed to select provider: %w", err)
}

fmt.Printf("Using provider: %s\n", provider.Name())

// Check if user is authenticated with the provider
if err := provider.CheckAuth(); err != nil {
return fmt.Errorf("%s authentication check failed: %w", provider.Name(), err)
}

// Create a temporary directory for downloading the model
// Clean up the temporary directory after the function returns
tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to consider expose this temp dir to user? because in some limited environment, user may only have write access to specific dirs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want users to specify which directory should be used to download models directly? As we are providing "" as the dir parameter's value to os.MkdirTemp , MkdirTemp uses the default directory for temporary files, as returned by TempDir.

if err != nil {
return fmt.Errorf("failed to create temporary directory: %w", err)
}
defer os.RemoveAll(tmpDir)

// Download the model
downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, tmpDir)
if err != nil {
return fmt.Errorf("failed to download model from %s: %w", provider.Name(), err)
}

// Update workspace to the downloaded model path
generateConfig.Workspace = downloadPath
fmt.Printf("Using downloaded model at: %s\n", downloadPath)
}

fmt.Printf("Generating modelfile for %s\n", generateConfig.Workspace)
modelfile, err := modelfile.NewModelfileByWorkspace(generateConfig.Workspace, generateConfig)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/config/modelfile/modelfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type GenerateConfig struct {
ParamSize string
Precision string
Quantization string
ModelURL string
Provider string // Explicit provider for short-form URLs (e.g., "huggingface", "modelscope")
ExcludePatterns []string
}

Expand All @@ -56,6 +58,8 @@ func NewGenerateConfig() *GenerateConfig {
ParamSize: "",
Precision: "",
Quantization: "",
ModelURL: "",
Provider: "",
ExcludePatterns: []string{},
}
}
Expand Down
127 changes: 127 additions & 0 deletions pkg/modelprovider/huggingface/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2025 The CNAI Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package huggingface

import (
"fmt"
"io"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
)

const (
huggingFaceBaseURL = "https://huggingface.co"
)

// parseModelURL parses a HuggingFace model URL and extracts the owner and repository name
func parseModelURL(modelURL string) (owner, repo string, err error) {
// Handle both full URLs and short-form repo names
modelURL = strings.TrimSpace(modelURL)

// Remove trailing slashes
modelURL = strings.TrimSuffix(modelURL, "/")

// If it's a full URL, parse it
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
u, err := url.Parse(modelURL)
if err != nil {
return "", "", fmt.Errorf("invalid URL: %w", err)
}

// Expected format: https://huggingface.co/owner/repo
parts := strings.Split(strings.Trim(u.Path, "/"), "/")
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid HuggingFace URL format, expected https://huggingface.co/owner/repo")
}

owner = parts[0]
repo = parts[1]
} else {
// Handle short-form like "owner/repo"
parts := strings.Split(modelURL, "/")
if len(parts) != 2 {
return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo")
}

owner = parts[0]
repo = parts[1]
}

if owner == "" || repo == "" {
return "", "", fmt.Errorf("owner and repository name cannot be empty")
}

return owner, repo, nil
}

// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace
func checkHuggingFaceAuth() error {
// Try to find the HF token
token := os.Getenv("HF_TOKEN")
if token != "" {
return nil
}

// Check if the token file exists
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user home directory: %w", err)
}

tokenPath := filepath.Join(homeDir, ".huggingface", "token")
if _, err := os.Stat(tokenPath); err == nil {
return nil
}

// Try using whoami command
if _, err := exec.LookPath("huggingface-cli"); err == nil {
cmd := exec.Command("huggingface-cli", "whoami")
cmd.Stdout = io.Discard
cmd.Stderr = io.Discard
if err := cmd.Run(); err == nil {
return nil
}
}

return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login")
}

// getToken retrieves the HuggingFace token from environment or token file
func getToken() (string, error) {
// First check environment variable
token := os.Getenv("HF_TOKEN")
if token != "" {
return token, nil
}

// Then check the token file
homeDir, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to get user home directory: %w", err)
}

tokenPath := filepath.Join(homeDir, ".huggingface", "token")
data, err := os.ReadFile(tokenPath)
if err != nil {
return "", fmt.Errorf("failed to read token file: %w", err)
}

return strings.TrimSpace(string(data)), nil
}
Loading