Files
PHD-Presentation/assets/crps_learning/weights_plot/index.qmd

202 lines
5.6 KiB
Plaintext

---
title: "Knots-Demo"
date: 2025-07-10
format:
  revealjs:
    embed-resources: true
execute:
  daemon: false
highlight-style: github
---
```{ojs}
d3 = require("d3@7")
```
```{ojs}
// Rows loaded from the attached basis_functions.csv. The cells below read the
// `x`, `y`, `b` and `mu` fields of each row. `typed: true` asks the CSV parser
// to infer column types (presumably so numeric columns arrive as numbers
// rather than strings; confirm against the FileAttachment docs).
bsplineData = FileAttachment("basis_functions.csv").csv({ typed: true })
```
```{ojs}
/**
 * Redraws the axes and the basis-function curves for one slice of the data.
 *
 * @param {object} g - d3 selection of the chart group that contains the
 *   `.x-axis` and `.y-axis` groups.
 * @param {Function} x - Horizontal scale; its domain is overwritten here.
 * @param {Function} y - Vertical scale; its domain is overwritten here.
 * @param {object} linesGroup - d3 selection whose `path` children are the curves.
 * @param {Function} color - Maps a curve's index to its stroke colour.
 * @param {Function} line - Path generator turning `{x, y}` points into a `d` string.
 * @param {Array<object>} data - Rows with `x`, `y` and `b` fields; rows sharing
 *   the same `b` form one curve.
 */
function updateChartInner(g, x, y, linesGroup, color, line, data) {
  // With no rows, d3.extent and d3.max return undefined, which would leave both
  // scales with NaN domains. Keep the previous domains in that case; the join
  // below still runs so stale curves fade out.
  if (data.length > 0) {
    x.domain(d3.extent(data, d => d.x));
    y.domain([0, d3.max(data, d => d.y)]);
  }

  // Animate both axes to the (possibly unchanged) domains.
  g.select(".x-axis").transition().duration(1500).call(d3.axisBottom(x).ticks(10));
  g.select(".y-axis").transition().duration(1500).call(d3.axisLeft(y).ticks(5));

  // One [key, rows] entry per basis function, keyed by `b` so that an existing
  // path keeps representing the same function across updates.
  const series = Array.from(d3.group(data, d => d.b));

  linesGroup.selectAll("path")
    .data(series, d => d[0])
    .join(
      // New curves start flat on y = 0 and invisible, then grow into place.
      enter => enter.append("path")
        .attr("fill", "none")
        .attr("stroke-width", 3)
        .attr("stroke", (_, i) => color(i))
        .attr("d", d => line(d[1].map(pt => ({x: pt.x, y: 0}))))
        .style("opacity", 0),
      update => update,
      // Curves without data fade out before being removed from the DOM.
      exit => exit.transition().duration(1000).style("opacity", 0).remove()
    )
    .transition().duration(1000)
    .attr("d", d => line(d[1]))
    .attr("stroke", (_, i) => color(i))
    .style("opacity", 1);
}
chart = {
// State variable for selected mu parameter
let selectedMu = 0.5;
const filteredData = () => bsplineData.filter(d =>
Math.abs(selectedMu - d.mu) < 0.001
);
const container = d3.create("div")
.style("max-width", "none")
.style("width", "100%");
const controlsContainer = container.append("div")
.style("display", "flex")
.style("gap", "20px")
.style("align-items", "center");
// Single slider control for mu
const sliderContainer = controlsContainer.append('div')
.style('display','flex')
.style('align-items','center')
.style('gap','10px')
.style('flex','1');
sliderContainer.append('label')
.text('Naive:')
.style('font-size','20px');
const muSlider = sliderContainer.append('input')
.attr('type','range')
.attr('min', 0)
.attr('max', 1)
.attr('step', 0.1)
.property('value', selectedMu)
.on('input', function(event) {
selectedMu = +this.value;
muDisplay.text(selectedMu.toFixed(1));
updateChart(filteredData());
})
.style('width', '100%');
const muDisplay = sliderContainer.append('span')
.text(selectedMu.toFixed(1))
.style('font-size','20px');
// Add Reset button
controlsContainer.append('button')
.text('Reset')
.style('font-size', '20px')
.style('align-self', 'center')
.style('margin-left', 'auto')
.on('click', () => {
selectedMu = 0.5;
muSlider.property('value', selectedMu);
muDisplay.text(selectedMu.toFixed(1));
updateChart(filteredData());
});
// Build SVG
const width = 1200;
const height = 450;
const margin = {top: 40, right: 20, bottom: 40, left: 40};
const innerWidth = width - margin.left - margin.right;
const innerHeight = height - margin.top - margin.bottom;
// Set controls container width to match SVG plot width
controlsContainer.style("max-width", "none").style("width", "100%");
// Distribute each control evenly and make sliders full-width
controlsContainer.selectAll("div").style("flex", "1").style("min-width", "0px");
controlsContainer.selectAll("input").style("width", "100%").style("box-sizing", "border-box");
// Create scales
const x = d3.scaleLinear()
.range([0, innerWidth]);
const y = d3.scaleLinear()
.range([innerHeight, 0]);
// Create a color scale for the basis functions
const color = d3.scaleOrdinal(["#80C684FF", "#FFD44EFF", "#D81A5FFF"]);
// Create SVG
const svg = d3.create("svg")
.attr("width", "100%")
.attr("height", "auto")
.attr("viewBox", [0, 0, width, height])
.attr("preserveAspectRatio", "xMidYMid meet")
.attr("style", "max-width: 100%; height: auto;");
// Create the chart group
const g = svg.append("g")
.attr("transform", `translate(${margin.left},${margin.top})`);
// Add axes
const xAxis = g.append("g")
.attr("transform", `translate(0,${innerHeight})`)
.attr("class", "x-axis")
.call(d3.axisBottom(x).ticks(10))
.style("font-size", "20px");
const yAxis = g.append("g")
.attr("class", "y-axis")
.call(d3.axisLeft(y).ticks(5))
.style("font-size", "20px");
// Add a horizontal line at y = 0
g.append("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", y(0))
.attr("y2", y(0))
.attr("stroke", "#000")
.attr("stroke-opacity", 0.2);
// Add gridlines
g.append("g")
.attr("class", "grid-lines")
.selectAll("line")
.data(y.ticks(5))
.join("line")
.attr("x1", 0)
.attr("x2", innerWidth)
.attr("y1", d => y(d))
.attr("y2", d => y(d))
.attr("stroke", "#ccc")
.attr("stroke-opacity", 0.5);
// Create a line generator
const line = d3.line()
.x(d => x(d.x))
.y(d => y(d.y))
.curve(d3.curveBasis);
// Group to contain the basis function lines
const linesGroup = g.append("g")
.attr("class", "basis-functions");
// Store the current basis functions for transition
let currentBasisFunctions = new Map();
// Function to update the chart with new data
function updateChart(data) {
updateChartInner(g, x, y, linesGroup, color, line, data);
}
// Store the update function
svg.node().update = updateChart;
// Initial render
updateChart(filteredData());
container.node().appendChild(svg.node());
return container.node();
}
```