Skip to content

Commit 59e3381

Browse files
committed
fix: ssh restart command
1 parent e120609 commit 59e3381

File tree

4 files changed

+79
-42
lines changed

4 files changed

+79
-42
lines changed

controller/thymis_controller/crud/deployment_info.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,15 @@ def get_first_device_host_by_config_id(session: Session, config_id: str) -> str
154154
return di.reachable_deployed_host if di else None
155155

156156

157-
def get_first_by_config_id(
158-
session: Session, config_id: str
159-
) -> db_models.DeploymentInfo | None:
157+
def get_first_by_config_id(session: Session, config_id: str):
160158
return (
161159
session.query(db_models.DeploymentInfo)
162160
.filter(db_models.DeploymentInfo.deployed_config_id == config_id)
163161
.first()
164162
)
165163

166164

167-
def get_by_config_id(
168-
session: Session, config_id: str
169-
) -> db_models.DeploymentInfo | None:
165+
def get_by_config_id(session: Session, config_id: str):
170166
return (
171167
session.query(db_models.DeploymentInfo)
172168
.filter(db_models.DeploymentInfo.deployed_config_id == config_id)

controller/thymis_controller/models/task.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,14 @@ class BuildDeviceImageTaskSubmission(BaseModel):
207207

208208
class SSHCommandTaskSubmission(BaseModel):
209209
type: Literal["ssh_command_task"] = "ssh_command_task"
210+
controller_access_client_endpoint: str
211+
deployment_info_id: uuid.UUID
212+
access_client_token: str
213+
deployment_public_key: str
214+
ssh_key_path: str
210215
target_user: str
211-
target_host: str
212216
target_port: int
213217
command: str
214-
ssh_key_path: str
215-
ssh_known_hosts_path: str
216218

217219

218220
class RunNixOSVMTaskSubmission(BaseModel):

controller/thymis_controller/routers/api_action.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
2+
import random
23
import uuid
34

45
from fastapi import APIRouter, Depends, HTTPException, Query, Response
56
from fastapi.responses import FileResponse, RedirectResponse
67
from thymis_agent import agent
78
from thymis_controller import crud, dependencies, models
89
from thymis_controller.config import global_settings
10+
from thymis_controller.crud.agent_token import create_access_client_token
911
from thymis_controller.dependencies import (
1012
DBSessionAD,
1113
NetworkRelayAD,
@@ -165,22 +167,36 @@ async def restart_device(
165167
identifier: str,
166168
db_session: DBSessionAD,
167169
task_controller: TaskControllerAD,
168-
project: ProjectAD,
170+
network_relay: NetworkRelayAD,
169171
user_session_id: UserSessionIDAD,
170172
):
171-
for target_host in crud.deployment_info.get_by_config_id(db_session, identifier):
172-
task_controller.submit(
173-
models.SSHCommandTaskSubmission(
174-
target_host=target_host.reachable_deployed_host,
175-
target_user="root",
176-
target_port=22,
177-
command="reboot",
178-
ssh_key_path=str(global_settings.PROJECT_PATH / "id_thymis"),
179-
ssh_known_hosts_path=str(project.known_hosts_path),
180-
),
181-
user_session_id=user_session_id,
182-
db_session=db_session,
183-
)
173+
for deployment_info in crud.deployment_info.get_by_config_id(
174+
db_session, identifier
175+
):
176+
if network_relay.public_key_to_connection_id.get(
177+
deployment_info.ssh_public_key
178+
):
179+
access_client_token = random.randbytes(32).hex()
180+
task = task_controller.submit(
181+
models.SSHCommandTaskSubmission(
182+
controller_access_client_endpoint=task_controller.access_client_endpoint,
183+
deployment_info_id=deployment_info.id,
184+
access_client_token=access_client_token,
185+
deployment_public_key=deployment_info.ssh_public_key,
186+
ssh_key_path=str(global_settings.PROJECT_PATH / "id_thymis"),
187+
target_user="root",
188+
target_port=22,
189+
command="reboot",
190+
),
191+
user_session_id=user_session_id,
192+
db_session=db_session,
193+
)
194+
create_access_client_token(
195+
db_session,
196+
deployment_info_id=deployment_info.id,
197+
token=access_client_token,
198+
deploy_device_task_id=task.id,
199+
)
184200

185201

186202
@router.head("/download-image")

controller/thymis_controller/task/worker.py

+42-19
Original file line numberDiff line numberDiff line change
@@ -533,26 +533,49 @@ def ssh_command_task(
533533
task_data = task.data
534534
assert task_data.type == "ssh_command_task"
535535

536-
returncode = run_command(
537-
task,
538-
conn,
539-
process_list,
540-
[
541-
"ssh",
542-
f"-o UserKnownHostsFile={task_data.ssh_known_hosts_path}",
543-
"-o StrictHostKeyChecking=yes",
544-
"-o ConnectTimeout=10",
545-
f"-i {task_data.ssh_key_path}",
546-
f"-p {task_data.target_port}",
547-
f"{task_data.target_user}@{task_data.target_host}",
548-
task_data.command,
549-
],
550-
)
536+
with tempfile.TemporaryDirectory() as tmpdir:
537+
# write deployment_public_key to tmpfile
538+
hostfile_path = f"{tmpdir}/known_hosts"
539+
with open(hostfile_path, "w", encoding="utf-8") as hostfile:
540+
hostfile.write(f"localhost {task_data.deployment_public_key}\n")
541+
hostfile.flush()
551542

552-
if returncode == 0:
553-
report_task_finished(task, conn)
554-
else:
555-
report_task_finished(task, conn, False, "SSH command failed")
543+
returncode = run_command(
544+
task,
545+
conn,
546+
process_list,
547+
[
548+
"ssh",
549+
"-i",
550+
f"{task_data.ssh_key_path}",
551+
"-o",
552+
f"UserKnownHostsFile={hostfile.name}",
553+
"-o",
554+
"StrictHostKeyChecking=yes",
555+
"-o",
556+
"PasswordAuthentication=no",
557+
"-o",
558+
"KbdInteractiveAuthentication=no",
559+
"-o",
560+
"ConnectTimeout=10",
561+
"-o",
562+
"BatchMode=yes",
563+
"-o",
564+
f"ProxyCommand={(os.getenv('PYTHONENV')+'/bin/python') if ('PYTHONENV' in os.environ) else 'python' } -m thymis_controller.access_client {task_data.controller_access_client_endpoint} {task_data.deployment_info_id}",
565+
f"{task_data.target_user}@localhost",
566+
task_data.command,
567+
],
568+
env={
569+
"PATH": os.getenv("PATH"),
570+
"HTTP_NETWORK_RELAY_SECRET": task_data.access_client_token,
571+
},
572+
cwd=tmpdir,
573+
)
574+
575+
if returncode == 0:
576+
report_task_finished(task, conn)
577+
else:
578+
report_task_finished(task, conn, False, "SSH command failed")
556579

557580

558581
SUPPORTED_TASK_TYPES = {

0 commit comments

Comments
 (0)