diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index d602bde47..17473a272 100644 --- a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -16,9 +16,11 @@ Depends: R (>= 4.1.0) Imports: bslib, + callr, DBI, duckdb, ellmer, + ggplot2, glue, htmltools, jsonlite, diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 27601e94a..653838638 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -28,24 +28,23 @@ #' #' @export querychat_init <- function( - df, - ..., - table_name = deparse(substitute(df)), - greeting = NULL, - data_description = NULL, - extra_instructions = NULL, - prompt_template = NULL, - system_prompt = querychat_system_prompt( df, - table_name, - # By default, pass through any params supplied to querychat_init() ..., - data_description = data_description, - extra_instructions = extra_instructions, - prompt_template = prompt_template - ), - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o") -) { + table_name = deparse(substitute(df)), + greeting = NULL, + data_description = NULL, + extra_instructions = NULL, + prompt_template = NULL, + system_prompt = querychat_system_prompt( + df, + table_name, + # By default, pass through any params supplied to querychat_init() + ..., + data_description = data_description, + extra_instructions = extra_instructions, + prompt_template = prompt_template + ), + create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o")) { is_table_name_ok <- is.character(table_name) && length(table_name) == 1 && grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE) @@ -139,7 +138,8 @@ querychat_ui <- function(id) { htmltools::tagList( # TODO: Make this into a proper HTML dependency shiny::includeCSS(system.file("www", "styles.css", package = "querychat")), - shinychat::chat_ui(ns("chat"), height = "100%", fill = TRUE) + shinychat::chat_ui(ns("chat"), height = "100%", fill = TRUE), + shiny::plotOutput(ns("llm_plot"), height = 300) ) } @@ -191,6 +191,47 @@ querychat_server <- function(id, 
querychat_config) { session = session ) } + plot_code <- shiny::reactiveVal(NULL) + # Preload the conversation with the system prompt. These are instructions for + # the chat model, and must not be shown to the end user. + chat <- create_chat_func(system_prompt = system_prompt) + output$llm_plot <- shiny::renderPlot({ + code <- plot_code() + if (is.null(code) || !nzchar(code)) { + return(NULL) + } + df <- filtered_df() + forbidden <- c("system", "file", "unlink", "assign", "library", "require") + if (any(sapply(forbidden, grepl, code))) { + stop("Forbidden function detected in plot code.") + } + res <- tryCatch( + { + callr::r( + function(code, df) { + p <- eval(parse(text = code)) + if (!inherits(p, "ggplot")) stop("Code did not return a ggplot object.") + p # return the ggplot object + }, + args = list(code = code, df = df), + show = TRUE, + stdout = TRUE, + stderr = TRUE, + ) + }, + error = function(e) { + message( + "Plot error: ", e$message, "\n", + "Code: ", code, "\n", + ) + plot.new() + text(0.5, 0.5, "Plot error. See R console for details.") + return(NULL) + } + ) + if (inherits(res, "ggplot")) print(res) + invisible(res) + }) # Modifies the data presented in the data dashboard, based on the given SQL # query, and also updates the title. @@ -219,6 +260,18 @@ querychat_server <- function(id, querychat_config) { } } + update_plot <- function(ggplot_code) { + plot_code(ggplot_code) + append_output("\n```r\n", ggplot_code, "\n```\n\n") + } + chat$register_tool(ellmer::tool( + update_plot, + "Updates the plot displayed in the data dashboard, based on the given ggplot code.", + ggplot_code = ellmer::type_string( + "A string containing R code that generates a ggplot object." + ) + )) + # Perform a SQL query on the data, and return the results as JSON. # @param query A DuckDB SQL query; must be a SELECT statement. # @return The results of the query as a JSON string. 
@@ -242,9 +295,6 @@ querychat_server <- function(id, querychat_config) { df |> jsonlite::toJSON(auto_unbox = TRUE) } - # Preload the conversation with the system prompt. These are instructions for - # the chat model, and must not be shown to the end user. - chat <- create_chat_func(system_prompt = system_prompt) chat$register_tool(ellmer::tool( update_dashboard, "Modifies the data presented in the data dashboard, based on the given SQL query, and also updates the title.", diff --git a/pkg-r/inst/prompt/prompt.md b/pkg-r/inst/prompt/prompt.md index 9ed80f43e..2ff7128f5 100644 --- a/pkg-r/inst/prompt/prompt.md +++ b/pkg-r/inst/prompt/prompt.md @@ -94,4 +94,32 @@ If you find yourself offering example questions to the user as part of your resp * `percentile_cont` and `percentile_disc` are "ordered set" aggregate functions. These functions are specified using the WITHIN GROUP (ORDER BY sort_expression) syntax, and they are converted to an equivalent aggregate function that takes the ordering expression as the first argument. For example, `percentile_cont(fraction) WITHIN GROUP (ORDER BY column [(ASC|DESC)])` is equivalent to `quantile_cont(column, fraction ORDER BY column [(ASC|DESC)])`. +## Task: Plotting with ggplot2 + +You can create and update plots in the dashboard using the `update_plot` tool. This tool takes a string of R code that generates a ggplot2 plot using the data frame `df` (which contains the currently filtered data). The code you provide will be evaluated and the resulting plot will be displayed in the dashboard. + +* Always use valid R code whose last expression evaluates to a ggplot2 plot object (no assignment needed; just return the plot object). +* The data frame available for plotting is named `df`. +* Do not attempt to retrieve or manipulate data outside of `df`. +* Only use plotting code that is safe and reproducible. 
+ +## Plotting guardrails + +When generating R code for plotting, you must never use or reference any of the following functions or statements: `system`, `file`, `unlink`, `assign`, `library`, `require`, or any function that accesses the file system, environment, or external resources. Only use functions from `ggplot2` and the provided data frame `df`. Because `library` and `require` are not allowed, always call ggplot2 functions with the `ggplot2::` prefix, as in the example below. Any attempt to use forbidden functions will result in an error and your code will not be executed. + +Example of plotting: + +> [User] +> Show me a scatterplot of x vs y. +> [/User] +> [ToolCall] +> update_plot({ggplot_code: "ggplot2::ggplot(df, ggplot2::aes(x = x, y = y)) + ggplot2::geom_point()"}) +> [/ToolCall] +> [ToolResponse] +> null +> [/ToolResponse] +> [Assistant] +> Here is a scatterplot of x vs y. +> [/Assistant] + {{extra_instructions}} \ No newline at end of file