From c96022c7b2c5224b42750b80d15a70aa456e8a4c Mon Sep 17 00:00:00 2001 From: Jonathan Berrisch Date: Fri, 23 May 2025 23:45:45 +0200 Subject: [PATCH] Add interactive plot to the slides --- 25_07_phd_defense/app.qmd | 28 ++- 25_07_phd_defense/assets/make_knots_data.R | 2 +- 25_07_phd_defense/index.qmd | 230 ++++++++++++++++++++- 3 files changed, 255 insertions(+), 5 deletions(-) diff --git a/25_07_phd_defense/app.qmd b/25_07_phd_defense/app.qmd index 9b956dd..5ae4e20 100644 --- a/25_07_phd_defense/app.qmd +++ b/25_07_phd_defense/app.qmd @@ -86,7 +86,29 @@ chart = { .style('width', '100%'); sliderCont.append('span').text(d => (d.label.match(/Sigma|Tailweight/) ? d.get() : d.get())) .style('font-size','20px'); - // Build SVG + + // Add Reset button to clear all sliders to their defaults + controlsContainer.append('button') + .text('Reset') + .style('font-size', '20px') + .style('align-self', 'center') + .style('margin-left', 'auto') + .on('click', () => { + // reset state vars + selectedMu = 0.5; + selectedSig = 1; + selectedNonc = 0; + selectedTailw = 1; + // update input positions + sliderCont.selectAll('input').property('value', d => d.get()); + // update displayed labels + sliderCont.selectAll('span') + .text(d => d.label.match(/Sigma|Tailweight/) ? (2**d.get()) : d.get()); + // redraw chart + updateChart(filteredData()); + }); + + // Build SVG const width = 800; const height = 400; const margin = {top: 40, right: 20, bottom: 40, left: 40}; @@ -105,7 +127,7 @@ chart = { .range([0, innerWidth]); const y = d3.scaleLinear() - .domain([0, 0.7]) + .domain([0, 1]) .range([innerHeight, 0]); // Create a color scale for the basis functions @@ -184,7 +206,7 @@ chart = { const line = d3.line() .x(d => x(d.x)) .y(d => y(d.y)) - .curve(d3.curveBasis); + .curve(d3.curveLinear); // Group to contain the basis function lines const linesGroup = g.append("g") diff --git a/25_07_phd_defense/assets/make_knots_data.R b/25_07_phd_defense/assets/make_knots_data.R index 5106bc8..caf24df 100644 --- a/25_07_phd_defense/assets/make_knots_data.R +++ b/25_07_phd_defense/assets/make_knots_data.R @@ -8,7 +8,7 @@ library(readr) # Creating faceted plots for different knot values and mu values # Create a function to generate the data for a given number of knots and mu value generate_basis_data <- function(num_knots, mu_value, sig_value, nonc_value, tailw_value, deg_value) { - grid <- seq(from = 0.01, to = 0.99, length.out = 50) + grid <- seq(from = 0.01, to = 0.99, length.out = 99) # Use provided degree B <- profoc:::make_basis_matrix(grid, profoc::make_knots( diff --git a/25_07_phd_defense/index.qmd b/25_07_phd_defense/index.qmd index 97c29f0..7df7037 100644 --- a/25_07_phd_defense/index.qmd +++ b/25_07_phd_defense/index.qmd @@ -2117,6 +2117,233 @@ We use `Rcpp` modules to expose a class to R ## Profoc - B-Spline Basis +::: {.panel-tabset} + +## Knot Placement Illustration + +```{ojs} +d3 = require("d3@7") +``` + +```{ojs} +bsplineData = FileAttachment("assets/mcrps_learning/basis_functions.csv").csv({ typed: true }) +``` + +```{ojs} +function updateChartInner(g, x, y, linesGroup, color, line, data) { + // Update axes with transitions + x.domain([0, d3.max(data, d => d.x)]); + g.select(".x-axis").transition().duration(1500).call(d3.axisBottom(x).ticks(10)); + y.domain([0, d3.max(data, d => d.y)]); + g.select(".y-axis").transition().duration(1500).call(d3.axisLeft(y).ticks(5)); + + // Group data by basis function + const dataByFunction = Array.from(d3.group(data, d => d.b)); + const keyFn = d => d[0]; + + // Update basis function lines + const u = linesGroup.selectAll("path").data(dataByFunction, keyFn); + u.join( + enter => enter.append("path").attr("fill","none").attr("stroke-width",3) + .attr("stroke", (_, i) => color(i)).attr("d", d => line(d[1].map(pt => ({x: pt.x, y: 0})))) + .style("opacity",0), + update => update, + exit => exit.transition().duration(1000).style("opacity",0).remove() + ) + .transition().duration(1000) + .attr("d", d => line(d[1])) + .attr("stroke", (_, i) => color(i)) + .style("opacity",1); +} + +chart = { + // State variables for selected parameters + let selectedMu = 0.5; + let selectedSig = 1; + let selectedNonc = 0; + let selectedTailw = 1; + const filteredData = () => bsplineData.filter(d => + Math.abs(selectedMu - d.mu) < 0.001 && + d.sig === selectedSig && + d.nonc === selectedNonc && + d.tailw === selectedTailw + ); + const container = d3.create("div") + .style("max-width", "none") + .style("width", "100%");; + const controlsContainer = container.append("div") + .style("display", "flex") + .style("gap", "20px"); + // slider controls + const sliders = [ + { label: 'Mu', get: () => selectedMu, set: v => selectedMu = v, min: 0.1, max: 0.9, step: 0.2 }, + { label: 'Sigma', get: () => Math.log2(selectedSig), set: v => selectedSig = 2 ** v, min: -2, max: 2, step: 1 }, + { label: 'Noncentrality', get: () => selectedNonc, set: v => selectedNonc = v, min: -4, max: 4, step: 2 }, + { label: 'Tailweight', get: () => Math.log2(selectedTailw), set: v => selectedTailw = 2 ** v, min: -2, max: 2, step: 1 } + ]; + // Build slider controls with D3 data join + const sliderCont = controlsContainer.selectAll('div').data(sliders).join('div') + .style('display','flex').style('align-items','center').style('gap','10px') + .style('flex','1').style('min-width','0px'); + sliderCont.append('label').text(d => d.label + ':').style('font-size','20px'); + sliderCont.append('input') + .attr('type','range').attr('min', d => d.min).attr('max', d => d.max).attr('step', d => d.step) + .property('value', d => d.get()) + .on('input', function(event, d) { + const val = +this.value; d.set(val); + d3.select(this.parentNode).select('span').text(d.label.match(/Sigma|Tailweight/) ? 2**val : val); + updateChart(filteredData()); + }) + .style('width', '100%'); + sliderCont.append('span').text(d => (d.label.match(/Sigma|Tailweight/) ? d.get() : d.get())) + .style('font-size','20px'); + + // Add Reset button to clear all sliders to their defaults + controlsContainer.append('button') + .text('Reset') + .style('font-size', '20px') + .style('align-self', 'center') + .style('margin-left', 'auto') + .on('click', () => { + // reset state vars + selectedMu = 0.5; + selectedSig = 1; + selectedNonc = 0; + selectedTailw = 1; + // update input positions + sliderCont.selectAll('input').property('value', d => d.get()); + // update displayed labels + sliderCont.selectAll('span') + .text(d => d.label.match(/Sigma|Tailweight/) ? (2**d.get()) : d.get()); + // redraw chart + updateChart(filteredData()); + }); + + // Build SVG + const width = 800; + const height = 400; + const margin = {top: 40, right: 20, bottom: 40, left: 40}; + const innerWidth = width - margin.left - margin.right; + const innerHeight = height - margin.top - margin.bottom; + + // Set controls container width to match SVG plot width + controlsContainer.style("max-width", "none").style("width", "100%"); + // Distribute each control evenly and make sliders full-width + controlsContainer.selectAll("div").style("flex", "1").style("min-width", "0px"); + controlsContainer.selectAll("input").style("width", "100%").style("box-sizing", "border-box"); + + // Create scales + const x = d3.scaleLinear() + .domain([0, 1]) + .range([0, innerWidth]); + + const y = d3.scaleLinear() + .domain([0, 1]) + .range([innerHeight, 0]); + + // Create a color scale for the basis functions + const color = d3.scaleOrdinal(d3.schemeCategory10); + + // Create SVG + const svg = d3.create("svg") + .attr("width", "100%") + .attr("height", "auto") + .attr("viewBox", [0, 0, width, height]) + .attr("preserveAspectRatio", "xMidYMid meet") + .attr("style", "max-width: 100%; height: auto;"); + + // Add chart title + // svg.append("text") + // .attr("class", "chart-title") + // .attr("x", width / 2) + // .attr("y", 20) + // .attr("text-anchor", "middle") + // .attr("font-size", "20px") + // .attr("font-weight", "bold"); + + // Create the chart group + const g = svg.append("g") + .attr("transform", `translate(${margin.left},${margin.top})`); + + // Add axes + const xAxis = g.append("g") + .attr("transform", `translate(0,${innerHeight})`) + .attr("class", "x-axis") + .call(d3.axisBottom(x).ticks(10)) + .style("font-size", "20px"); + + const yAxis = g.append("g") + .attr("class", "y-axis") + .call(d3.axisLeft(y).ticks(5)) + .style("font-size", "20px"); + + // Add axis labels + // g.append("text") + // .attr("x", innerWidth / 2) + // .attr("y", innerHeight + 35) + // .attr("text-anchor", "middle") + // .text("x"); + + // g.append("text") + // .attr("transform", "rotate(-90)") + // .attr("x", -innerHeight / 2) + // .attr("y", -30) + // .attr("text-anchor", "middle") + // .text("y"); + + // Add a horizontal line at y = 0 + g.append("line") + .attr("x1", 0) + .attr("x2", innerWidth) + .attr("y1", y(0)) + .attr("y2", y(0)) + .attr("stroke", "#000") + .attr("stroke-opacity", 0.2); + + // Add gridlines + g.append("g") + .attr("class", "grid-lines") + .selectAll("line") + .data(y.ticks(5)) + .join("line") + .attr("x1", 0) + .attr("x2", innerWidth) + .attr("y1", d => y(d)) + .attr("y2", d => y(d)) + .attr("stroke", "#ccc") + .attr("stroke-opacity", 0.5); + + // Create a line generator + const line = d3.line() + .x(d => x(d.x)) + .y(d => y(d.y)) + .curve(d3.curveLinear); + + // Group to contain the basis function lines + const linesGroup = g.append("g") + .attr("class", "basis-functions"); + + // Store the current basis functions for transition + let currentBasisFunctions = new Map(); + + // Function to update the chart with new data + function updateChart(data) { + updateChartInner(g, x, y, linesGroup, color, line, data); + } + + // Store the update function + svg.node().update = updateChart; + + // Initial render + updateChart(filteredData()); + + container.node().appendChild(svg.node()); + return container.node(); +} +``` + +## Knot Placement Details + :::: {.columns} ::: {.column width="48%"} @@ -2151,12 +2378,13 @@ TODO: Add actual algorithm to backup slides ::: {.column width="48%"} -TODO ::: :::: +:::: + ## Wrap-Up :::: {.columns}