Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions ControlR.Agent.Common/Services/AgentHubClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using ControlR.Libraries.Api.Contracts.Dtos.ServerApi;
using ControlR.Libraries.Shared.Helpers;
using ControlR.Libraries.Api.Contracts.Hubs.Clients;
using ControlR.Libraries.Api.Contracts.Enums;
using ControlR.Libraries.Ipc.Interfaces;
using ControlR.Libraries.Signalr.Client.Extensions;
using Microsoft.AspNetCore.SignalR;
Expand Down Expand Up @@ -916,6 +917,141 @@ public async Task<ValidateFilePathResponseDto> ValidateFilePath(ValidateFilePath
}
}

public async Task ExecuteScript(Guid executionId, string scriptContent, ShellType shellType, ScriptRunAs runAs)
{
if (runAs == ScriptRunAs.CurrentUser || runAs == ScriptRunAs.CurrentUserElevated)
{
Task.Run(async () =>
{
try
{
var sessions = await _desktopSessionProvider.GetActiveDesktopClients();
if (sessions == null || sessions.Length == 0)
{
_logger.LogWarning("No active user session found to execute script {ExecutionId} as current user.", executionId);
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, "Agent Error: No active user session found to execute script as current user." + Environment.NewLine, true, -1);
return;
}

var session = sessions[0];
if (!_ipcServerStore.TryGetServer(session.ProcessId, out var ipcServer))
{
_logger.LogWarning("No IPC server found for process ID {ProcessId} for script {ExecutionId}.", session.ProcessId, executionId);
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, $"Agent Error: No IPC connection to interactive user session (Process {session.ProcessId})." + Environment.NewLine, true, -1);
return;
}

await ipcServer.Server.Client.ExecuteScript(new ExecuteScriptIpcDto(executionId, scriptContent, shellType, runAs));
}
catch (Exception ex)
{
_logger.LogError(ex, "Error forwarding script {ExecutionId} to desktop client", executionId);
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, $"Agent Error: {ex.Message}" + Environment.NewLine, true, -1);
}
}).Forget();
return;
}

Task.Run(async () =>
{
string? tempFilePath = null;
try
{
var ext = shellType switch
{
ShellType.PowerShell => ".ps1",
ShellType.Cmd => ".bat",
ShellType.Bash => ".sh",
_ => ".txt"
};
tempFilePath = Path.Combine(Path.GetTempPath(), $"controlr_script_{executionId}{ext}");
await File.WriteAllTextAsync(tempFilePath, scriptContent);

string fileName;
string arguments;

if (shellType == ShellType.PowerShell)
{
fileName = _systemEnvironment.IsWindows() ? "powershell.exe" : "pwsh";
arguments = $"-NoProfile -NonInteractive -ExecutionPolicy Bypass -File \"{tempFilePath}\"";
}
else if (shellType == ShellType.Cmd)
{
fileName = "cmd.exe";
arguments = $"/c \"{tempFilePath}\"";
}
else // Bash
{
fileName = "/bin/bash";
arguments = $"\"{tempFilePath}\"";
}

var startInfo = new ProcessStartInfo
{
FileName = fileName,
Arguments = arguments,
RedirectStandardOutput = true,
RedirectStandardError = true,
UseShellExecute = false,
CreateNoWindow = true
};

using var process = new Process { StartInfo = startInfo };

process.OutputDataReceived += async (sender, e) =>
{
if (e.Data != null)
{
await _hubConnection.Server.SendScriptOutput(executionId, e.Data + Environment.NewLine, string.Empty, false, null);
}
};

process.ErrorDataReceived += async (sender, e) =>
{
if (e.Data != null)
{
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, e.Data + Environment.NewLine, false, null);
}
};

process.Start();
process.BeginOutputReadLine();
process.BeginErrorReadLine();

var completed = await process.WaitForExitAsync().WaitAsync(TimeSpan.FromMinutes(5)).ContinueWith(t => t.IsCompletedSuccessfully);

if (!completed)
{
process.Kill(true);
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, "Script execution timed out.", true, -1);
}
else
{
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, string.Empty, true, process.ExitCode);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error executing script {ExecutionId}", executionId);
await _hubConnection.Server.SendScriptOutput(executionId, string.Empty, $"Agent Error: {ex.Message}", true, -1);
}
finally
{
if (tempFilePath != null && File.Exists(tempFilePath))
{
try
{
File.Delete(tempFilePath);
}
catch
{
// Ignore
}
}
}
}).Forget();
}

private async Task<CheckOsPermissionsResponseIpcDto> EnsureDesktopClientPermissionGranted(
IDesktopClientRpcService desktopClient,
int targetProcessId,
Expand Down
13 changes: 13 additions & 0 deletions ControlR.Agent.Common/Services/AgentRpcService.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ControlR.Libraries.Ipc.Interfaces;
using ControlR.Libraries.Api.Contracts.Dtos.IpcDtos;

namespace ControlR.Agent.Common.Services;

Expand Down Expand Up @@ -32,4 +33,16 @@ public async Task<bool> SendChatResponse(ChatResponseIpcDto dto)
return false;
}
}

public async Task SendScriptOutput(ScriptOutputIpcDto dto)
{
try
{
await _hubConnection.Server.SendScriptOutput(dto.ExecutionId, dto.StdOut, dto.StdErr, dto.IsFinished, dto.ExitCode);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error while forwarding script output for execution {ExecutionId} to server", dto.ExecutionId);
}
}
}
13 changes: 13 additions & 0 deletions ControlR.ApiClient/ControlrApi.Endpoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,16 @@ public interface IServerVersionApi
{
Task<ApiResult<Version>> GetCurrentServerVersion(CancellationToken cancellationToken = default);
}

public interface IScriptsApi
{
Task<ApiResult<ScriptDto>> CreateScript(ScriptCreateRequestDto request, CancellationToken cancellationToken = default);
Task<ApiResult<ScriptDto[]>> GetAllScripts(CancellationToken cancellationToken = default);
Task<ApiResult<ScriptDto>> GetScript(Guid id, CancellationToken cancellationToken = default);
Task<ApiResult<ScriptDto>> UpdateScript(Guid id, ScriptCreateRequestDto request, CancellationToken cancellationToken = default);
Task<ApiResult> DeleteScript(Guid id, CancellationToken cancellationToken = default);
Task<ApiResult<ScriptExecutionDto[]>> ExecuteScript(Guid id, Guid[] deviceIds, ScriptRunAs runAs = ScriptRunAs.System, CancellationToken cancellationToken = default);
Task<ApiResult<ScriptExecutionDto[]>> ExecuteAdHocScript(ExecuteScriptRequestDto request, CancellationToken cancellationToken = default);
Task<ApiResult<ScriptExecutionDto>> GetScriptExecution(Guid executionId, CancellationToken cancellationToken = default);
Task<ApiResult<ScriptExecutionDto[]>> GetAllExecutions(CancellationToken cancellationToken = default);
}
5 changes: 4 additions & 1 deletion ControlR.ApiClient/ControlrApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public interface IControlrApi
IUsersApi Users { get; }
IUserServerSettingsApi UserServerSettings { get; }
IUserTagsApi UserTags { get; }
IScriptsApi Scripts { get; }
}

public partial class ControlrApi(
Expand Down Expand Up @@ -67,7 +68,8 @@ public partial class ControlrApi(
IUserTagsApi,
IUsersApi,
IAgentVersionApi,
IServerVersionApi
IServerVersionApi,
IScriptsApi
{
private readonly ControlrApiClientAuthState _authState = authState;
private readonly IBearerTokenRefresher _bearerTokenRefresher = bearerTokenRefresher;
Expand Down Expand Up @@ -98,6 +100,7 @@ public partial class ControlrApi(
public ITestEmailApi TestEmail => this;
public IUserPreferencesApi UserPreferences => this;
public IUserRolesApi UserRoles => this;
public IScriptsApi Scripts => this;
public IUsersApi Users => this;
public IUserServerSettingsApi UserServerSettings => this;
public IUserTagsApi UserTags => this;
Expand Down
84 changes: 84 additions & 0 deletions ControlR.ApiClient/Implementations/ControlrApi.Scripts.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using System.Net.Http.Json;
using ControlR.Libraries.Api.Contracts.Constants;
using ControlR.Libraries.Api.Contracts.Dtos;
using ControlR.Libraries.Api.Contracts.Dtos.ServerApi;

using ControlR.Libraries.Api.Contracts.Enums;

namespace ControlR.ApiClient;

public partial class ControlrApi
{
async Task<ApiResult<ScriptDto>> IScriptsApi.CreateScript(ScriptCreateRequestDto request, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
{
using var response = await _client.PostAsJsonAsync(HttpConstants.ScriptsEndpoint, request, cancellationToken);
await response.EnsureSuccessStatusCodeWithDetails();
return await response.Content.ReadFromJsonAsync<ScriptDto>(cancellationToken);
});
}

async Task<ApiResult<ScriptDto[]>> IScriptsApi.GetAllScripts(CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
await _client.GetFromJsonAsync<ScriptDto[]>(HttpConstants.ScriptsEndpoint, cancellationToken));
}

async Task<ApiResult<ScriptDto>> IScriptsApi.GetScript(Guid id, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
await _client.GetFromJsonAsync<ScriptDto>($"{HttpConstants.ScriptsEndpoint}/{id}", cancellationToken));
}

async Task<ApiResult<ScriptDto>> IScriptsApi.UpdateScript(Guid id, ScriptCreateRequestDto request, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
{
using var response = await _client.PutAsJsonAsync($"{HttpConstants.ScriptsEndpoint}/{id}", request, cancellationToken);
await response.EnsureSuccessStatusCodeWithDetails();
return await response.Content.ReadFromJsonAsync<ScriptDto>(cancellationToken);
});
}

async Task<ApiResult> IScriptsApi.DeleteScript(Guid id, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
{
using var response = await _client.DeleteAsync($"{HttpConstants.ScriptsEndpoint}/{id}", cancellationToken);
await response.EnsureSuccessStatusCodeWithDetails();
});
}

async Task<ApiResult<ScriptExecutionDto[]>> IScriptsApi.ExecuteScript(Guid id, Guid[] deviceIds, ScriptRunAs runAs, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
{
using var response = await _client.PostAsJsonAsync($"{HttpConstants.ScriptsEndpoint}/{id}/execute?runAs={runAs}", deviceIds, cancellationToken);
await response.EnsureSuccessStatusCodeWithDetails();
return await response.Content.ReadFromJsonAsync<ScriptExecutionDto[]>(cancellationToken);
});
}

async Task<ApiResult<ScriptExecutionDto[]>> IScriptsApi.ExecuteAdHocScript(ExecuteScriptRequestDto request, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
{
using var response = await _client.PostAsJsonAsync($"{HttpConstants.ScriptsEndpoint}/execute-adhoc", request, cancellationToken);
await response.EnsureSuccessStatusCodeWithDetails();
return await response.Content.ReadFromJsonAsync<ScriptExecutionDto[]>(cancellationToken);
});
}

async Task<ApiResult<ScriptExecutionDto>> IScriptsApi.GetScriptExecution(Guid executionId, CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
await _client.GetFromJsonAsync<ScriptExecutionDto>($"{HttpConstants.ScriptsEndpoint}/executions/{executionId}", cancellationToken));
}

async Task<ApiResult<ScriptExecutionDto[]>> IScriptsApi.GetAllExecutions(CancellationToken cancellationToken)
{
return await ExecuteApiCall(async () =>
await _client.GetFromJsonAsync<ScriptExecutionDto[]>($"{HttpConstants.ScriptsEndpoint}/executions", cancellationToken));
}
}
Loading