Skip to content

Commit 8a70590

Browse files
committed
feat(run): add log message for successful detachment from SSH target
Signed-off-by: zhangwei <[email protected]>
1 parent c31b0fb commit 8a70590

File tree

4 files changed

+166
-110
lines changed

4 files changed

+166
-110
lines changed

pkg/app/run.go

Lines changed: 71 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,18 @@ import (
2929
"github.com/tensorchord/envd/pkg/app/telemetry"
3030
"github.com/tensorchord/envd/pkg/envd"
3131
"github.com/tensorchord/envd/pkg/home"
32-
"github.com/tensorchord/envd/pkg/ssh"
3332
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
3433
"github.com/tensorchord/envd/pkg/syncthing"
3534
"github.com/tensorchord/envd/pkg/types"
3635
"github.com/tensorchord/envd/pkg/util/fileutil"
37-
"github.com/tensorchord/envd/pkg/util/netutil"
3836
)
3937

4038
var CommandCreate = &cli.Command{
41-
Name: "run",
42-
Category: CategoryBasic,
43-
Aliases: []string{"c"},
44-
Usage: "Run the envd environment from the existing image",
45-
Hidden: false,
46-
Description: `run is only supported in envd-server runner currently`,
39+
Name: "run",
40+
Category: CategoryBasic,
41+
Aliases: []string{"c"},
42+
Usage: "Run the envd environment from the existing image",
43+
Hidden: false,
4744
Flags: []cli.Flag{
4845
&cli.StringFlag{
4946
Name: "image",
@@ -105,7 +102,7 @@ var CommandCreate = &cli.Command{
105102
},
106103
&cli.BoolFlag{
107104
Name: "sync",
108-
Usage: "Sync the local directory with the remote container",
105+
Usage: "Sync the local directory with the remote container (only supported in envd-server runner currently)",
109106
Value: false,
110107
},
111108
&cli.StringSliceFlag{
@@ -132,9 +129,9 @@ func run(clicontext *cli.Context) error {
132129
return err
133130
}
134131

135-
name := clicontext.String("name")
136-
if name == "" {
137-
name = strings.ToLower(randomdata.SillyName())
132+
environmentName := clicontext.String("name")
133+
if environmentName == "" {
134+
environmentName = strings.ToLower(randomdata.SillyName())
138135
}
139136
opt := envd.StartOptions{
140137
SshdHost: clicontext.String("host"),
@@ -144,7 +141,7 @@ func run(clicontext *cli.Context) error {
144141
NumCPU: clicontext.String("cpus"),
145142
NumGPU: clicontext.Int("gpu"),
146143
ShmSize: clicontext.Int("shm-size"),
147-
EnvironmentName: name,
144+
EnvironmentName: environmentName,
148145
}
149146
switch c.Runner {
150147
case types.RunnerTypeEnvdServer:
@@ -171,6 +168,8 @@ func run(clicontext *cli.Context) error {
171168
return err
172169
}
173170

171+
ctr := filepath.Base(opt.BuildContext)
172+
detach := clicontext.Bool("detach")
174173
logger := logrus.WithFields(logrus.Fields{
175174
"cmd": "run",
176175
"StartOptions": opt,
@@ -185,105 +184,83 @@ func run(clicontext *cli.Context) error {
185184
return errors.Wrap(err, "failed to get the ssh hostname")
186185
}
187186

188-
ac, err := home.GetManager().AuthGetCurrent()
189-
if err != nil {
190-
return errors.Wrap(err, "failed to get the auth information")
191-
}
192-
username, err := sshname.Username(ac.Name, res.Name)
193-
if err != nil {
194-
return errors.Wrap(err, "failed to get the username")
187+
var eo sshconfig.EntryOptions
188+
switch c.Runner {
189+
case types.RunnerTypeEnvdServer:
190+
ac, err := home.GetManager().AuthGetCurrent()
191+
if err != nil {
192+
return errors.Wrap(err, "failed to get the auth information")
193+
}
194+
username, err := sshname.Username(ac.Name, res.Name)
195+
if err != nil {
196+
return errors.Wrap(err, "failed to get the username")
197+
}
198+
eo = sshconfig.EntryOptions{
199+
Name: res.Name,
200+
IFace: hostname,
201+
Port: res.SSHPort,
202+
PrivateKeyPath: clicontext.Path("private-key"),
203+
EnableHostKeyCheck: false,
204+
EnableAgentForward: false,
205+
User: username,
206+
}
207+
case types.RunnerTypeDocker:
208+
eo, err = engine.GenerateSSHConfig(ctr, hostname,
209+
clicontext.Path("private-key"), res)
210+
if err != nil {
211+
return errors.Wrap(err, "failed to get the ssh entry")
212+
}
195213
}
196214

197-
eo := sshconfig.EntryOptions{
198-
Name: res.Name,
199-
IFace: hostname,
200-
Port: res.SSHPort,
201-
PrivateKeyPath: clicontext.Path("private-key"),
202-
EnableHostKeyCheck: false,
203-
EnableAgentForward: false,
204-
User: username,
205-
}
206215
if err = sshconfig.AddEntry(eo); err != nil {
207216
logger.WithError(err).
208217
Infof("failed to add entry %s to your SSH config file", res.Name)
209218
return errors.Wrap(err, "failed to add entry to your SSH config file")
210219
}
211220

212-
// TODO(gaocegege): Test why it fails.
213-
if !clicontext.Bool("detach") {
221+
if !detach {
214222
outputChannel := make(chan error)
215-
opt := ssh.DefaultOptions()
216-
opt.PrivateKeyPath = clicontext.Path("private-key")
217-
opt.Port = res.SSHPort
218-
opt.AgentForwarding = false
219-
opt.User = username
220-
opt.Server = hostname
221-
222-
sshClient, err := ssh.NewClient(opt)
223-
if err != nil {
224-
outputChannel <- errors.Wrap(err, "failed to create the ssh client")
225-
}
226-
227-
ports := res.Ports
228-
229-
for _, p := range ports {
230-
if p.Port == 2222 {
231-
continue
232-
}
233-
234-
// TODO(gaocegege): Use one remote port.
235-
localPort, err := netutil.GetFreePort()
236-
if err != nil {
237-
return errors.Wrap(err, "failed to get a free port")
238-
}
239-
localAddress := fmt.Sprintf("%s:%d", "localhost", localPort)
240-
remoteAddress := fmt.Sprintf("%s:%d", "localhost", p.Port)
241-
logger.Infof(`service "%s" is listening at %s\n`, p.Name, localAddress)
242-
go func() {
243-
if err := sshClient.LocalForward(localAddress, remoteAddress); err != nil {
244-
outputChannel <- errors.Wrap(err, "failed to forward to local port")
223+
if c.Runner == types.RunnerTypeEnvdServer {
224+
if clicontext.Bool("sync") {
225+
go func() {
226+
if err = engine.LocalForward(hostname, clicontext.Path("private-key"), res, syncthing.DefaultRemoteAPIAddress, syncthing.DefaultRemoteAPIAddress); err != nil {
227+
outputChannel <- errors.Wrap(err, "failed to forward to remote api port")
228+
}
229+
}()
230+
231+
go func() {
232+
syncthingRemoteAddr := fmt.Sprintf("127.0.0.1:%s", syncthing.ParsePortFromAddress(syncthing.DefaultRemoteDeviceAddress))
233+
if err = engine.LocalForward(hostname, clicontext.Path("private-key"), res, syncthingRemoteAddr, syncthingRemoteAddr); err != nil {
234+
outputChannel <- errors.Wrap(err, "failed to forward to remote port")
235+
}
236+
}()
237+
238+
go func() {
239+
syncthingLocalAddr := fmt.Sprintf("127.0.0.1:%s", syncthing.ParsePortFromAddress(syncthing.DefaultLocalDeviceAddress))
240+
if err = engine.RemoteForward(hostname, clicontext.Path("private-key"), res, syncthingLocalAddr, syncthingLocalAddr); err != nil {
241+
outputChannel <- errors.Wrap(err, "failed to forward to local port")
242+
}
243+
}()
244+
245+
localSyncthing, _, err := startSyncthing(res.Name)
246+
if err != nil {
247+
return errors.Wrap(err, "failed to start syncthing")
245248
}
246-
}()
247-
}
248-
249-
if clicontext.Bool("sync") {
250-
go func() {
251-
if err := sshClient.LocalForward(syncthing.DefaultRemoteAPIAddress, syncthing.DefaultRemoteAPIAddress); err != nil {
252-
outputChannel <- errors.Wrap(err, "failed to forward to remote api port")
253-
}
254-
}()
255-
256-
go func() {
257-
syncthingRemoteAddr := fmt.Sprintf("127.0.0.1:%s", syncthing.ParsePortFromAddress(syncthing.DefaultRemoteDeviceAddress))
258-
if err := sshClient.LocalForward(syncthingRemoteAddr, syncthingRemoteAddr); err != nil {
259-
outputChannel <- errors.Wrap(err, "failed to forward to remote port")
260-
}
261-
}()
262-
263-
go func() {
264-
syncthingLocalAddr := fmt.Sprintf("127.0.0.1:%s", syncthing.ParsePortFromAddress(syncthing.DefaultLocalDeviceAddress))
265-
if err := sshClient.RemoteForward(syncthingLocalAddr, syncthingLocalAddr); err != nil {
266-
outputChannel <- errors.Wrap(err, "failed to forward to local port")
267-
}
268-
}()
269-
270-
localSyncthing, _, err := startSyncthing(res.Name)
271-
if err != nil {
272-
return errors.Wrap(err, "failed to start syncthing")
249+
defer localSyncthing.StopLocalSyncthing()
273250
}
274-
defer localSyncthing.StopLocalSyncthing()
275-
276251
}
277252

278253
go func() {
279-
// TODO(gaocegege): Avoid the hard code.
280-
if err := sshClient.Attach(); err != nil {
281-
outputChannel <- errors.Wrap(err, "failed to attach to the container")
254+
if err = engine.Attach(ctr, hostname,
255+
clicontext.Path("private-key"), res, nil); err != nil {
256+
outputChannel <- errors.Wrap(err, "failed to attach to the ssh target")
282257
}
258+
logrus.Infof("Detached successfully. You can attach to the container with command `ssh %s.envd`\n",
259+
environmentName)
283260
outputChannel <- nil
284261
}()
285262

286-
if err := <-outputChannel; err != nil {
263+
if err = <-outputChannel; err != nil {
287264
return err
288265
}
289266
}

pkg/envd/docker.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,24 +328,58 @@ func (e dockerEngine) GenerateSSHConfig(name, iface, privateKeyPath string,
328328
return eo, nil
329329
}
330330

331-
func (e dockerEngine) Attach(name, iface, privateKeyPath string,
332-
startResult *StartResult, g ir.Graph) error {
331+
func (e dockerEngine) newSSHClient(server string, port int, privateKeyPath string) (ssh.Client, error) {
333332
opt := ssh.DefaultOptions()
334-
opt.Server = iface
333+
opt.Server = server
335334
opt.PrivateKeyPath = privateKeyPath
336-
opt.Port = startResult.SSHPort
335+
opt.Port = port
337336
sshClient, err := ssh.NewClient(opt)
338337
if err != nil {
339-
return errors.Wrap(err, "failed to create the ssh client")
338+
return nil, errors.Wrap(err, "failed to create the ssh client")
340339
}
341-
opt.Server = iface
340+
return sshClient, nil
341+
}
342342

343-
if err := sshClient.Attach(); err != nil {
343+
func (e dockerEngine) Attach(name, iface, privateKeyPath string,
344+
startResult *StartResult, g ir.Graph) error {
345+
sshClient, err := e.newSSHClient(iface, startResult.SSHPort, privateKeyPath)
346+
if err != nil {
347+
return err
348+
}
349+
defer sshClient.Close()
350+
351+
if err = sshClient.Attach(); err != nil {
344352
return errors.Wrap(err, "failed to attach to the container")
345353
}
346354
return nil
347355
}
348356

357+
func (e dockerEngine) LocalForward(iface, privateKeyPath string, startResult *StartResult, localAddress, targetAddress string) error {
358+
sshClient, err := e.newSSHClient(iface, startResult.SSHPort, privateKeyPath)
359+
if err != nil {
360+
return err
361+
}
362+
defer sshClient.Close()
363+
364+
if err = sshClient.LocalForward(localAddress, targetAddress); err != nil {
365+
return errors.Wrap(err, "failed to forward to local port")
366+
}
367+
return nil
368+
}
369+
370+
func (e dockerEngine) RemoteForward(iface, privateKeyPath string, startResult *StartResult, localAddress, targetAddress string) error {
371+
sshClient, err := e.newSSHClient(iface, startResult.SSHPort, privateKeyPath)
372+
if err != nil {
373+
return err
374+
}
375+
defer sshClient.Close()
376+
377+
if err = sshClient.RemoteForward(localAddress, targetAddress); err != nil {
378+
return errors.Wrap(err, "failed to forward to remote port")
379+
}
380+
return nil
381+
}
382+
349383
// StartEnvd creates the container for the given tag and container name.
350384
func (e dockerEngine) StartEnvd(ctx context.Context, so StartOptions) (*StartResult, error) {
351385
logger := logrus.WithFields(logrus.Fields{

pkg/envd/engine.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ type SSHClient interface {
5757
startResult *StartResult) (sshconfig.EntryOptions, error)
5858
Attach(name, iface, privateKeyPath string,
5959
startResult *StartResult, g ir.Graph) error
60+
LocalForward(iface, privateKeyPath string, startResult *StartResult, localAddress, targetAddress string) error
61+
RemoteForward(iface, privateKeyPath string, startResult *StartResult, localAddress, targetAddress string) error
6062
}
6163

6264
type ImageClient interface {

pkg/envd/envdserver.go

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,24 +152,32 @@ func (e envdServerEngine) GenerateSSHConfig(name, iface, privateKeyPath string,
152152
return eo, nil
153153
}
154154

155+
func (e envdServerEngine) newSSHClient(username, server string, port int, privateKeyPath string) (ssh.Client, error) {
156+
opt := ssh.DefaultOptions()
157+
opt.PrivateKeyPath = privateKeyPath
158+
opt.Port = port
159+
opt.AgentForwarding = false
160+
opt.User = username
161+
opt.Server = server
162+
sshClient, err := ssh.NewClient(opt)
163+
if err != nil {
164+
return nil, errors.Wrap(err, "failed to create the ssh client")
165+
}
166+
return sshClient, nil
167+
}
168+
155169
func (e envdServerEngine) Attach(name, iface, privateKeyPath string, startResult *StartResult, g ir.Graph) error {
156170
username, err := sshname.Username(e.Loginname, startResult.Name)
157171
if err != nil {
158172
return errors.Wrap(err, "failed to get the username")
159173
}
160174

161175
outputChannel := make(chan error)
162-
opt := ssh.DefaultOptions()
163-
opt.PrivateKeyPath = privateKeyPath
164-
opt.Port = startResult.SSHPort
165-
opt.AgentForwarding = false
166-
opt.User = username
167-
opt.Server = iface
168-
169-
sshClient, err := ssh.NewClient(opt)
176+
sshClient, err := e.newSSHClient(username, iface, startResult.SSHPort, privateKeyPath)
170177
if err != nil {
171-
outputChannel <- errors.Wrap(err, "failed to create the ssh client")
178+
outputChannel <- err
172179
}
180+
defer sshClient.Close()
173181

174182
ports := startResult.Ports
175183

@@ -207,6 +215,41 @@ func (e envdServerEngine) Attach(name, iface, privateKeyPath string, startResult
207215
return nil
208216
}
209217

218+
func (e *envdServerEngine) LocalForward(iface, privateKeyPath string, startResult *StartResult, localAddress, targetAddress string) error {
219+
username, err := sshname.Username(e.Loginname, startResult.Name)
220+
if err != nil {
221+
return errors.Wrap(err, "failed to get the username")
222+
}
223+
sshClient, err := e.newSSHClient(username, iface, startResult.SSHPort, privateKeyPath)
224+
if err != nil {
225+
return err
226+
}
227+
defer sshClient.Close()
228+
229+
if err = sshClient.LocalForward(localAddress, targetAddress); err != nil {
230+
return errors.Wrap(err, "failed to forward to local port")
231+
}
232+
return nil
233+
}
234+
235+
func (e *envdServerEngine) RemoteForward(iface, privateKeyPath string, startResult *StartResult, localAddress, targetAddress string) error {
236+
username, err := sshname.Username(e.Loginname, startResult.Name)
237+
if err != nil {
238+
return errors.Wrap(err, "failed to get the username")
239+
}
240+
241+
sshClient, err := e.newSSHClient(username, iface, startResult.SSHPort, privateKeyPath)
242+
if err != nil {
243+
return err
244+
}
245+
defer sshClient.Close()
246+
247+
if err = sshClient.RemoteForward(localAddress, targetAddress); err != nil {
248+
return errors.Wrap(err, "failed to forward to remote port")
249+
}
250+
return nil
251+
}
252+
210253
func (e *envdServerEngine) ListEnvRuntimeGraph(ctx context.Context, env string) (*ir.RuntimeGraph, error) {
211254
resp, err := e.EnvironmentGet(ctx, env)
212255
if err != nil {

0 commit comments

Comments
 (0)