Skip to content

API Reference

PhasePlaneWidget

PhasePlaneWidget

Bases: AnyWidget

Interactive phase plane widget.

In Jupyter / VS Code the widget is backed by anywidget. All numerical work (ODE integration, fixed-point search, nullclines, sweeps) is done in the browser by the JS front-end, so interactivity is instantaneous.

For static sites, use `to_standalone_html()` to obtain a self-contained `.html` file that can be dropped into MkDocs, GitHub Pages, etc.

Source code in src/tvb_phaseplane/widget.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
class PhasePlaneWidget(anywidget.AnyWidget):
    """Interactive phase plane widget.

    In Jupyter / VS Code the widget is backed by anywidget.  All numerical
    work (ODE integration, fixed-point search, nullclines, sweeps) is done
    in the browser by the JS front-end, so interactivity is instantaneous.

    For static sites use :meth:`to_standalone_html` to obtain a self-contained
    ``.html`` file that can be dropped into mkdocs, GitHub Pages, etc.
    LaTeX/PGFPlots output is available through :meth:`export_tikz`.
    """

    # Inline JS/CSS so the widget is self-contained and standalone HTML exports work.
    # Decode explicitly as UTF-8: ``read_text()`` without an encoding uses the
    # platform locale (e.g. cp1252 on Windows), which can garble or reject
    # non-ASCII characters in the bundled assets.
    _esm = (_STATIC_DIR / "widget.js").read_text(encoding="utf-8")
    _css = (_STATIC_DIR / "widget.css").read_text(encoding="utf-8")

    # ── Initial state (synced to JS front-end) ──
    model_name = traitlets.Unicode("wilson_cowan").tag(sync=True)
    params = traitlets.Dict({}).tag(sync=True)
    # Parameter metadata: name -> [min, max, default, description].
    param_info = traitlets.Dict({}).tag(sync=True)
    state_names = traitlets.List(["x", "y"]).tag(sync=True)

    x0 = traitlets.Float(0.1).tag(sync=True)
    y0 = traitlets.Float(0.1).tag(sync=True)

    xlim = traitlets.List([-0.5, 1.5]).tag(sync=True)
    ylim = traitlets.List([-0.5, 1.5]).tag(sync=True)
    t_max = traitlets.Float(100.0).tag(sync=True)

    # Pre-computed data (populated by JS, kept for inspection / export)
    nullcline_x = traitlets.List([]).tag(sync=True)
    nullcline_y = traitlets.List([]).tag(sync=True)
    vector_field = traitlets.List([]).tag(sync=True)
    fixed_points = traitlets.List([]).tag(sync=True)
    trajectory = traitlets.List([]).tag(sync=True)

    sweep_results = traitlets.List([]).tag(sync=True)
    sweep_fixed_points = traitlets.List([]).tag(sync=True)
    sweep_running = traitlets.Bool(False).tag(sync=True)

    # Display toggles
    show_nullclines = traitlets.Bool(True).tag(sync=True)
    show_vector_field = traitlets.Bool(True).tag(sync=True)
    show_trajectory = traitlets.Bool(True).tag(sync=True)
    show_fixed_points = traitlets.Bool(True).tag(sync=True)

    # Integrator / noise
    integrator = traitlets.Unicode("rk4").tag(sync=True)
    noise_enable = traitlets.Bool(False).tag(sync=True)
    noise_sigma = traitlets.List([]).tag(sync=True)

    # Custom model specification (JSON-serialisable dict for JS)
    model_spec = traitlets.Dict(allow_none=True, default_value=None).tag(sync=True)

    # Display indices for multi-variable projections: positions in
    # ``state_names`` of the variables plotted on the x and y axes.
    display = traitlets.List([0, 1]).tag(sync=True)

    # Clamped values for non-displayed state variables (``None`` entries for
    # the displayed ones).
    clamped = traitlets.List(default_value=None, allow_none=True).tag(sync=True)

    # Layout mode: "full" shows controls + phase plane + time series + sweep
    # "phase_plane" shows only the phase plane canvas (useful in notebooks / iframes)
    display_mode = traitlets.Unicode("full").tag(sync=True)

    _DISPLAY_MODES = frozenset({"full", "phase_plane"})

    @traitlets.validate("display_mode")
    def _validate_display_mode(self, proposal):
        """Reject layout modes the front-end does not know about."""
        v = proposal["value"]
        if v not in self._DISPLAY_MODES:
            raise traitlets.TraitError(
                f"display_mode must be one of {self._DISPLAY_MODES}, got {v!r}"
            )
        return v

    # Optional model object passed to ``__init__``; when ``None`` the model is
    # looked up by ``model_name`` in the registry (see ``_get_model``).
    _model_instance = None

    def __init__(self, model=None, **kwargs):
        """Create the widget.

        Parameters
        ----------
        model : object, optional
            Model instance used for the Python-side helpers.  It must expose
            ``name``, ``param_info``, ``state_names``, ``default_xlim``,
            ``default_ylim``, ``dim`` and ``f(t, state, params)``.  Its
            ``name`` becomes the default ``model_name`` unless one is passed.
        **kwargs
            Initial trait values forwarded to :class:`anywidget.AnyWidget`.
        """
        if model is not None:
            self._model_instance = model
            kwargs.setdefault("model_name", model.name)
        super().__init__(**kwargs)
        self._update_model()
        # Register handler for custom JS → Python messages (e.g. TikZ export)
        self.on_msg(self._on_custom_msg)

    def _get_model(self):
        """Return the explicit model instance, or a fresh registry model.

        Names missing from ``MODEL_REGISTRY`` fall back to the
        ``wilson_cowan`` entry.
        NOTE(review): this presumably includes ``"custom"``, so Python-side
        helpers evaluate Wilson-Cowan for custom specs — confirm.
        """
        from .models import MODEL_REGISTRY

        if self._model_instance is not None:
            return self._model_instance
        cls = MODEL_REGISTRY.get(self.model_name, MODEL_REGISTRY["wilson_cowan"])
        return cls()

    def _update_model(self):
        """Push model metadata to the JS front-end."""
        model = self._get_model()
        self.param_info = model.param_info
        self.state_names = model.state_names
        # Index 2 of each param_info entry is the parameter's default value.
        self.params = {k: v[2] for k, v in model.param_info.items()}
        self.xlim = model.default_xlim
        self.ylim = model.default_ylim

    @traitlets.observe("model_name")
    def _on_model_change(self, change):
        """Refresh metadata for built-in models; custom specs use set_model_spec."""
        if self.model_name != "custom":
            self._update_model()

    def set_model_spec(self, spec: dict):
        """Load a custom model from a ``ModelSpec`` dict.

        Also derives the initial ``params``, ``param_info``, ``state_names``,
        ``display``, axis limits, initial point and ``clamped`` values from
        the spec so the widget has sensible defaults before JS takes over.

        Parameters
        ----------
        spec : dict
            JSON-serialisable model specification (see
            :meth:`ModelSpec.to_widget_state`).
        """
        self.model_spec = spec
        self.model_name = "custom"

        parameters = spec.get("parameters", {})
        self.params = {n: v["default"] for n, v in parameters.items()}
        # Same [min, max, default, description] layout as the built-in models
        # so the existing slider infrastructure can be reused.
        param_info = {}
        for n, v in parameters.items():
            lo, hi = v["range"]
            param_info[n] = [lo, hi, v["default"], f"Parameter {n}"]
        self.param_info = param_info

        state_vars = spec.get("state_vars", {})
        state_names = list(state_vars.keys())
        self.state_names = state_names
        # Sync display indices
        display = spec.get("display", [0, min(1, len(state_names) - 1)])
        self.display = display

        # Axis limits and the initial point come from the *displayed*
        # variables: display[0] on the x axis, display[1] on the y axis.
        if state_names:
            x_name = state_names[display[0]] if display else state_names[0]
            lo, hi = state_vars[x_name]["range"]
            self.xlim = [lo, hi]
            self.x0 = state_vars[x_name]["default"]
        if len(state_names) > 1:
            y_name = state_names[display[1]] if len(display) > 1 else state_names[1]
            lo, hi = state_vars[y_name]["range"]
            self.ylim = [lo, hi]
            self.y0 = state_vars[y_name]["default"]

        # Non-displayed variables start clamped at the middle of their range;
        # displayed ones are marked None (not clamped).
        clamped = []
        for i, name in enumerate(state_names):
            if i in display:
                clamped.append(None)
            else:
                lo, hi = state_vars[name]["range"]
                clamped.append((lo + hi) / 2.0)
        self.clamped = clamped

    # ── Python-side helpers (for programmatic use / validation) ──

    def run_sweep(self, param_name: str, values: list):
        """Placeholder for running a parameter sweep from Python.

        Sweeps are computed client-side by the JS front-end and their results
        are synced into ``sweep_results`` / ``sweep_fixed_points``.  This
        method currently performs no work: it neither starts a sweep nor
        touches any trait.

        Parameters
        ----------
        param_name : str
            Parameter to vary (currently ignored).
        values : list of float
            Values to evaluate (currently ignored).
        """
        # Intentionally a no-op.  For headless computation use the model
        # classes in ``models.py`` directly.
        return None

    # ── Standalone HTML export ──

    def to_standalone_html(
        self,
        filename: str | pathlib.Path,
        title: str = "Phase Plane Widget",
        *,
        on_render_js: str = "",
    ):
        """Export the widget to a self-contained HTML file.

        The resulting ``.html`` file contains the full JS computation engine,
        all model definitions, the CSS, and the current widget state.  It works
        in any modern browser with **no Python runtime and no Jupyter kernel**.

        Parameters
        ----------
        filename : str or pathlib.Path
            Output path (e.g. ``"widget.html"``).
        title : str
            Page ``<title>``; HTML-escaped before insertion.
        on_render_js : str
            Optional JavaScript snippet executed after the widget renders.
            Useful for auto-opening UI panels (e.g. the live editor).  It is
            inserted verbatim, so only pass trusted code.
        """
        from html import escape

        js_code = self._esm
        css_code = self._css

        state = {
            "model_name": self.model_name,
            "params": self.params,
            "param_info": self.param_info,
            "state_names": self.state_names,
            "x0": self.x0,
            "y0": self.y0,
            "xlim": self.xlim,
            "ylim": self.ylim,
            "t_max": self.t_max,
            "show_nullclines": self.show_nullclines,
            "show_vector_field": self.show_vector_field,
            "show_trajectory": self.show_trajectory,
            "show_fixed_points": self.show_fixed_points,
            "nullcline_x": self.nullcline_x,
            "nullcline_y": self.nullcline_y,
            "vector_field": self.vector_field,
            "fixed_points": self.fixed_points,
            "trajectory": self.trajectory,
            "sweep_results": self.sweep_results,
            "sweep_fixed_points": self.sweep_fixed_points,
            "sweep_param": "",
            "sweep_running": False,
            "model_spec": self.model_spec,
            "display": list(self.display) if self.display else [0, 1],
            "clamped": list(self.clamped) if self.clamped else None,
            "integrator": self.integrator,
            "noise_enable": self.noise_enable,
            "noise_sigma": self.noise_sigma,
            "display_mode": self.display_mode,
        }

        # Characters such as "<" or "&" in the title must not break the markup.
        safe_title = escape(title)
        # The state is embedded in an inline <script>; a literal "</script>" or
        # "<!--" inside any string value (e.g. in a custom model spec) would end
        # the script element early.  The \u003c escape is valid in both JSON and
        # JavaScript string literals and decodes back to "<".
        state_json = json.dumps(state, indent=2).replace("<", "\\u003c")

        extra_js = f"\n{on_render_js}\n" if on_render_js else ""

        page = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>{safe_title}</title>
    <style>
{css_code}
    </style>
</head>
<body>
<div id="ppw-root"></div>
<script type="module">
{js_code}

const initialState = {state_json};

const mockModel = {{
  _isMock: true,
  _data: initialState,
  _callbacks: {{}},
  get(name) {{ return this._data[name]; }},
  set(name, value) {{ this._data[name] = value; }},
  save_changes() {{}},
  on(event, cb) {{
    if (!this._callbacks[event]) this._callbacks[event] = [];
    this._callbacks[event].push(cb);
  }},
  send() {{}},
}};

render({{ model: mockModel, el: document.getElementById('ppw-root') }});
{extra_js}
</script>
</body>
</html>"""

        pathlib.Path(filename).write_text(page, encoding="utf-8")

    # ── TikZ / PGFPlots export ──

    # Stability label -> (TikZ node shape, draw colour, fill colour).
    STABILITY_MARKERS = {
        "stable_node": ("circle", "green!70!black", "green!70!black"),
        "stable_focus": ("circle", "green!70!black", "green!70!black"),
        "unstable_node": ("circle", "red!70!black", "white"),
        "unstable_focus": ("circle", "red!70!black", "white"),
        "saddle": ("diamond", "purple", "purple"),
    }

    def _display_indices(self):
        """Return the ``(x, y)`` state indices plotted on the axes.

        Falls back to ``(0, 1)`` when ``display`` is empty or has fewer than
        two entries, matching the default of the ``display`` trait.
        """
        if self.display and len(self.display) >= 2:
            return self.display[0], self.display[1]
        return 0, 1

    def _compute_vector_field_for_tikz(self, nx=15, ny=15):
        r"""Compute a normalized vector-field grid for embedding in TikZ.

        Non-displayed variables are held at their ``clamped`` values (0.0 when
        unset); the displayed pair sweeps an ``nx`` x ``ny`` grid over
        ``xlim`` x ``ylim``.  Each arrow points along the projection of the
        flow onto the displayed variables and has a fixed length of 3% of the
        smaller axis span.  Grid points where that projected derivative
        vanishes are skipped.

        Returns a list of arrow endpoints ``[x1, y1, x2, y2]`` suitable for
        ``\draw[-stealth] (axis cs:x1,y1) -- (axis cs:x2,y2);``.
        """
        model = self._get_model()
        ix, iy = self._display_indices()
        xs = np.linspace(self.xlim[0], self.xlim[1], nx)
        ys = np.linspace(self.ylim[0], self.ylim[1], ny)
        span_x = self.xlim[1] - self.xlim[0]
        span_y = self.ylim[1] - self.ylim[0]
        scale = 0.03 * min(span_x, span_y)

        # Clamped values pin only the non-displayed variables.  Build this base
        # state once; grid coordinates are written on top of it, so a clamp
        # can never overwrite them.
        base_state = [0.0] * model.dim
        if self.clamped:
            for i, val in enumerate(self.clamped):
                if val is not None and i < model.dim:
                    base_state[i] = val

        arrows = []
        for xi in xs:
            for yi in ys:
                state = list(base_state)
                state[ix] = xi
                if model.dim > 1:
                    state[iy] = yi
                d = model.f(0, state, self.params)
                # Use the derivatives of the *displayed* variables, not
                # blindly components 0 and 1.
                u, v = d[ix], d[iy]
                norm = np.hypot(u, v)
                if norm > 1e-12:
                    x2 = xi + scale * u / norm
                    y2 = yi + scale * v / norm
                    arrows.append([float(xi), float(yi), float(x2), float(y2)])
        return arrows

    @staticmethod
    def _fmt_coords(data):
        """Format a list of [x, y] pairs as PGFPlots ``coordinates`` block.

        An empty input yields a single dummy point at the origin.
        """
        if not data:
            return "coordinates {(0,0)}"
        pairs = " ".join(f"({row[0]:.6f},{row[1]:.6f})" for row in data)
        return f"coordinates {{\n  {pairs}\n}}"

    @staticmethod
    def _tex_escape(text):
        """Escape underscores so names can be typeset in LaTeX text mode."""
        # Done outside f-string expressions: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12.
        return str(text).replace("_", "\\_")

    def _build_tikz(self):
        """Return the standalone LaTeX/PGFPlots source for the current view.

        Layers are emitted only when their ``show_*`` toggle is on.  The
        vector field is recomputed in Python.  Nullclines, the trajectory and
        fixed points come from the synced front-end data.  A small
        annotation with the model name and parameter values is appended.
        """
        ix, iy = self._display_indices()
        n_states = len(self.state_names)
        x_label = self.state_names[ix] if 0 <= ix < n_states else "x"
        y_label = self.state_names[iy] if 0 <= iy < n_states else "y"
        xlim = self.xlim
        ylim = self.ylim

        # ── vector field ──
        vfield_data = []
        if self.show_vector_field:
            vfield_data = self._compute_vector_field_for_tikz()

        # ── trajectories ──
        traj_data = []
        if self.show_trajectory and self.trajectory:
            # Rows are assumed to be [t, s0, s1, ...], hence the +1 offset.
            idx_x = 1 + ix
            idx_y = 1 + iy
            traj_data = [
                [float(row[idx_x]), float(row[idx_y])] for row in self.trajectory
            ]

        # ── nullclines ──
        nc_x = self.nullcline_x if self.show_nullclines else []
        nc_y = self.nullcline_y if self.show_nullclines else []

        # ── fixed points ──
        fps = self.fixed_points if self.show_fixed_points else []

        # ── build plot commands ──
        plots = []

        if vfield_data:
            lines = "\n".join(
                f"            \\draw[-stealth, gray] (axis cs:{x1:.6f},{y1:.6f}) -- (axis cs:{x2:.6f},{y2:.6f});"
                for x1, y1, x2, y2 in vfield_data
            )
            plots.append(lines)

        if nc_x:
            plots.append(
                f"            \\addplot[blue, thick, no marks, smooth] {self._fmt_coords(nc_x)};"
            )

        if nc_y:
            plots.append(
                f"            \\addplot[red, thick, no marks, smooth] {self._fmt_coords(nc_y)};"
            )

        if traj_data:
            plots.append(
                f"            \\addplot[green!60!black, thick, no marks, smooth] {self._fmt_coords(traj_data)};"
            )

        plots_block = "\n".join(plots)

        # ── fixed-point nodes ──
        fp_nodes = []
        for fp in fps:
            # Entries are expected as [x, y, stability]; skip malformed ones.
            if len(fp) < 3:
                continue
            x_fp, y_fp, stability = float(fp[0]), float(fp[1]), fp[2]
            shape, color, fill = self.STABILITY_MARKERS.get(
                stability, ("circle", "black", "white")
            )
            # Foci get a small inner dot to distinguish them from nodes.
            inner = ""
            if stability == "stable_focus":
                inner = (
                    f"\\node[fill={color}, circle, inner sep=0.8pt] "
                    f"at (axis cs:{x_fp:.6f},{y_fp:.6f}) {{}};"
                )
            elif stability == "unstable_focus":
                inner = (
                    f"\\node[draw={color}, fill={color}, circle, inner sep=0.8pt] "
                    f"at (axis cs:{x_fp:.6f},{y_fp:.6f}) {{}};"
                )
            opts = f"draw={color}"
            if fill == "white":
                opts += ", fill=white"
            elif fill != color:
                opts += f", fill={fill}"
            node_tex = (
                f"\\node[{opts}, {shape}, inner sep=1.5pt] "
                f"at (axis cs:{x_fp:.6f},{y_fp:.6f}) {{}};"
            )
            if inner:
                node_tex += "\n            " + inner
            fp_nodes.append(node_tex)
        fp_block = "\n            ".join(fp_nodes)

        # ── parameter annotation ──
        param_lines = ", ".join(
            f"{self._tex_escape(k)}={v:.4g}" for k, v in self.params.items()
        )
        model_label = self._tex_escape(self.model_name)
        param_node = (
            r"\path (rel axis cs:0.02,0.98) node[anchor=north west, font=\tiny, align=left] "
            f"{{{model_label}\\\\ {param_lines}}};"
        )

        return (
            r"\documentclass[border=5pt]{standalone}" + "\n"
            r"\usepackage{pgfplots}" + "\n"
            r"\usepackage{tikz}" + "\n"
            r"\pgfplotsset{compat=1.17}" + "\n"
            r"\usetikzlibrary{shapes.geometric}" + "\n"
            r"\begin{document}" + "\n"
            r"\begin{tikzpicture}" + "\n"
            r"\begin{axis}[" + "\n"
            "    width=10cm, height=10cm,\n"
            f"    xmin={xlim[0]}, xmax={xlim[1]}, ymin={ylim[0]}, ymax={ylim[1]},\n"
            f"    xlabel=${x_label}$, ylabel=${y_label}$,\n"
            "    axis lines=middle,\n"
            "    enlargelimits=true,\n"
            "]\n"
            + plots_block + "\n"
            + (fp_block + "\n" if fp_block else "")
            + "            " + param_node + "\n"
            r"\end{axis}" + "\n"
            r"\end{tikzpicture}" + "\n"
            r"\end{document}" + "\n"
        )

    def export_tikz(self, filename: str | pathlib.Path = "phase_plane.tex"):
        """Generate a self-contained ``.tex`` file.

        The resulting file can be compiled with ``pdflatex`` or ``lualatex``
        (requires the ``pgfplots`` package).

        Parameters
        ----------
        filename : str or pathlib.Path
            Output path for the ``.tex`` file.

        Returns
        -------
        str
            The path that was written.
        """
        filename = pathlib.Path(filename)
        filename.write_text(self._build_tikz(), encoding="utf-8")
        return str(filename)

    # ── Custom message handler (JS → Python) ──

    def _on_custom_msg(self, _widget, content, buffers):
        """Handle messages from the JS front-end.

        ``{"type": "export_tikz"}`` is answered with a ``tikz_data`` message
        carrying the generated LaTeX source; other messages are ignored.
        """
        if content.get("type") == "export_tikz":
            # Build the source in memory; no temporary file is needed.
            self.send(
                {
                    "type": "tikz_data",
                    "content": self._build_tikz(),
                    "filename": "phase_plane.tex",
                }
            )

Functions

__init__

__init__(model=None, **kwargs)
Source code in src/tvb_phaseplane/widget.py
def __init__(self, model=None, **kwargs):
    """Initialise the widget, optionally around an existing model object.

    When ``model`` is given it is stored for the Python-side helpers, and its
    ``name`` is used as ``model_name`` unless the caller supplied one.  After
    trait initialisation the model metadata is pushed to the front-end and
    the JS → Python message handler is registered.
    """
    has_model = model is not None
    if has_model:
        self._model_instance = model
        if "model_name" not in kwargs:
            kwargs["model_name"] = model.name
    super().__init__(**kwargs)
    self._update_model()
    # Listen for custom front-end messages (e.g. TikZ export requests).
    self.on_msg(self._on_custom_msg)

to_standalone_html

to_standalone_html(filename: str | Path, title: str = 'Phase Plane Widget', *, on_render_js: str = '')

Export the widget to a self-contained HTML file.

The resulting .html file contains the full JS computation engine, all model definitions, the CSS, and the current widget state. It works in any modern browser with no Python runtime and no Jupyter kernel.

Parameters

- `filename` (str or pathlib.Path): Output path (e.g. `"widget.html"`).
- `title` (str): Page `<title>`.
- `on_render_js` (str): Optional JavaScript snippet executed after the widget renders. Useful for auto-opening UI panels (e.g. the live editor).

Source code in src/tvb_phaseplane/widget.py
    def to_standalone_html(
        self,
        filename: str | pathlib.Path,
        title: str = "Phase Plane Widget",
        *,
        on_render_js: str = "",
    ):
        """Export the widget to a self-contained HTML file.

        The resulting ``.html`` file contains the full JS computation engine,
        all model definitions, the CSS, and the current widget state.  It works
        in any modern browser with **no Python runtime and no Jupyter kernel**.

        Parameters
        ----------
        filename : str or pathlib.Path
            Output path (e.g. ``"widget.html"``).
        title : str
            Page ``<title>``.
        on_render_js : str
            Optional JavaScript snippet executed after the widget renders.
            Useful for auto-opening UI panels (e.g. the live editor).
        """
        js_code = self._esm
        css_code = self._css

        state = {
            "model_name": self.model_name,
            "params": self.params,
            "param_info": self.param_info,
            "state_names": self.state_names,
            "x0": self.x0,
            "y0": self.y0,
            "xlim": self.xlim,
            "ylim": self.ylim,
            "t_max": self.t_max,
            "show_nullclines": self.show_nullclines,
            "show_vector_field": self.show_vector_field,
            "show_trajectory": self.show_trajectory,
            "show_fixed_points": self.show_fixed_points,
            "nullcline_x": self.nullcline_x,
            "nullcline_y": self.nullcline_y,
            "vector_field": self.vector_field,
            "fixed_points": self.fixed_points,
            "trajectory": self.trajectory,
            "sweep_results": self.sweep_results,
            "sweep_fixed_points": self.sweep_fixed_points,
            "sweep_param": "",
            "sweep_running": False,
            "model_spec": self.model_spec,
            "display": list(self.display) if self.display else [0, 1],
            "clamped": list(self.clamped) if self.clamped else None,
            "integrator": self.integrator,
            "noise_enable": self.noise_enable,
            "noise_sigma": self.noise_sigma,
            "display_mode": self.display_mode,
        }

        extra_js = f"\n{on_render_js}\n" if on_render_js else ""

        html = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>{title}</title>
    <style>
{css_code}
    </style>
</head>
<body>
<div id="ppw-root"></div>
<script type="module">
{js_code}

const initialState = {json.dumps(state, indent=2)};

const mockModel = {{
  _isMock: true,
  _data: initialState,
  _callbacks: {{}},
  get(name) {{ return this._data[name]; }},
  set(name, value) {{ this._data[name] = value; }},
  save_changes() {{}},
  on(event, cb) {{
    if (!this._callbacks[event]) this._callbacks[event] = [];
    this._callbacks[event].push(cb);
  }},
  send() {{}},
}};

render({{ model: mockModel, el: document.getElementById('ppw-root') }});
{extra_js}
</script>
</body>
</html>"""

        pathlib.Path(filename).write_text(html, encoding="utf-8")

run_sweep

run_sweep(param_name: str, values: list)

Placeholder for running a parameter sweep from Python. Sweeps are computed client-side by the JS front-end; this method currently performs no work.

Parameters

- `param_name` (str): Parameter to vary.
- `values` (list of float): Values to evaluate.

Source code in src/tvb_phaseplane/widget.py
def run_sweep(self, param_name: str, values: list):
    """Placeholder for running a parameter sweep from Python.

    Sweeps are computed client-side by the JS front-end, and their results
    are synced into the ``sweep_results`` / ``sweep_fixed_points`` traits.
    This method currently performs no work: it neither starts a sweep nor
    touches any trait, and it returns ``None``.

    Parameters
    ----------
    param_name : str
        Parameter to vary (currently ignored).
    values : list of float
        Values to evaluate (currently ignored).
    """
    # Intentionally a no-op.  For headless computation use the model
    # classes in ``models.py`` directly.
    return None

Model Base Class

BaseModel

Base class for neural mass models.

Source code in src/tvb_phaseplane/models.py
class BaseModel:
    """Base class for neural mass models.

    Subclasses describe a model through the class attributes below and
    implement :meth:`f`.  The remaining methods are generic numerical helpers
    built on ``f``.  The vector-field, nullcline, fixed-point and regime
    helpers pass two-element states to ``f``, so they assume a
    two-dimensional model.
    """

    name = "base"
    dim = 2
    state_names = ["x", "y"]
    default_params = {}
    # Parameter metadata: name -> [min, max, default, description].
    param_info = {}
    # Default plot window for the first and second state variables.
    default_xlim = [-1.0, 1.0]
    default_ylim = [-1.0, 1.0]

    @property
    def state_vars(self):
        """Ordered list of all state variable names (alias for ``state_names``)."""
        return list(self.state_names)

    def to_model_spec(self):
        """Return a dict compatible with :class:`ModelSpec` and the JS widget.

        Built-in models use this to populate the ``model_spec`` traitlet when
        the user switches to a custom-model view.  State variables default to
        0.0.  The first two take their ranges from ``default_xlim`` and
        ``default_ylim``; any others get (-5, 5).  Parameter slider steps are
        1/500 of each parameter's range.
        """
        state_vars = {}
        for i, name in enumerate(self.state_names):
            if i == 0:
                lo, hi = self.default_xlim
            elif i == 1:
                lo, hi = self.default_ylim
            else:
                lo, hi = -5.0, 5.0
            state_vars[name] = {"default": 0.0, "range": [lo, hi]}

        parameters = {}
        for name, (lo, hi, default, _desc) in self.param_info.items():
            parameters[name] = {
                "default": default,
                "range": [lo, hi],
                "step": (hi - lo) / 500.0,
            }

        # Equations as strings — for built-in models the JS already has the
        # compiled functions, so we return empty equations and rely on the
        # ``model_name`` traitlet to select the hard-coded JS model.
        return {
            "name": self.name,
            "state_vars": state_vars,
            "parameters": parameters,
            "equations": {},
            "display": list(range(min(2, len(self.state_names)))),
            "custom_functions": {},
            "integrator": "rk4",
            "noise_per_var": None,
        }

    def f(self, t, state, params):
        """Compute derivatives. Returns list of length dim."""
        raise NotImplementedError

    def compute_vector_field(self, params, xlim, ylim, n_grid=12):
        """Return ``[x, y, dx/dt, dy/dt]`` rows on an ``n_grid`` x ``n_grid`` grid.

        The derivatives are raw (not normalised).
        """
        x = np.linspace(xlim[0], xlim[1], n_grid)
        y = np.linspace(ylim[0], ylim[1], n_grid)
        vectors = []
        for xi in x:
            for yi in y:
                d = self.f(0, [xi, yi], params)
                vectors.append([float(xi), float(yi), float(d[0]), float(d[1])])
        return vectors

    def compute_nullclines(self, params, xlim, ylim, n_grid=60):
        """Compute nullclines by finding zero crossings on a grid.

        Returns ``(nc_x, nc_y)``: lists of ``[x, y]`` points where
        ``dx/dt`` and ``dy/dt`` respectively change sign or vanish.
        """
        x = np.linspace(xlim[0], xlim[1], n_grid)
        y = np.linspace(ylim[0], ylim[1], n_grid)
        X, Y = np.meshgrid(x, y)

        dx = np.zeros_like(X)
        dy = np.zeros_like(Y)
        for i in range(n_grid):
            for j in range(n_grid):
                d = self.f(0, [X[i, j], Y[i, j]], params)
                dx[i, j] = d[0]
                dy[i, j] = d[1]

        nc_x = self._find_zero_crossings(X, Y, dx)
        nc_y = self._find_zero_crossings(X, Y, dy)
        return nc_x, nc_y

    def _find_zero_crossings(self, X, Y, Z):
        """Find points where Z=0 using linear interpolation.

        ``X``, ``Y`` and ``Z`` are 2-D arrays of the same shape, as produced by
        ``np.meshgrid``; the grid need not be square.  Sign changes are
        searched along rows (interpolating in x) and then along columns
        (interpolating in y).  Exact zeros may be reported more than once.
        """
        points = []
        n_rows, n_cols = Z.shape

        # Horizontal neighbours: interpolate the crossing in x.
        for i in range(n_rows):
            for j in range(n_cols - 1):
                if Z[i, j] == 0:
                    points.append([float(X[i, j]), float(Y[i, j])])
                elif Z[i, j] * Z[i, j + 1] < 0:
                    t = abs(Z[i, j]) / (abs(Z[i, j]) + abs(Z[i, j + 1]))
                    px = X[i, j] + t * (X[i, j + 1] - X[i, j])
                    py = Y[i, j]
                    points.append([float(px), float(py)])

        # Vertical neighbours: interpolate the crossing in y.
        for i in range(n_rows - 1):
            for j in range(n_cols):
                if Z[i, j] == 0:
                    points.append([float(X[i, j]), float(Y[i, j])])
                elif Z[i, j] * Z[i + 1, j] < 0:
                    t = abs(Z[i, j]) / (abs(Z[i, j]) + abs(Z[i + 1, j]))
                    px = X[i, j]
                    py = Y[i, j] + t * (Y[i + 1, j] - Y[i, j])
                    points.append([float(px), float(py)])

        return points

    def find_fixed_points(self, params, xlim, ylim, n_grid=25):
        """Find fixed points by grid search + numerical refinement.

        ``fsolve`` is started from every point of an ``n_grid`` x ``n_grid``
        grid.  A solution is kept when all of the following hold:

        - ``fsolve`` reports convergence.
        - It lies within the window expanded by 0.5 on every side.
        - Its residual norm is at most 0.1.
        - It is at least 0.08 away from every point already found.

        Returns a list of ``[x, y, stability]`` entries, where ``stability``
        is a label from :meth:`_classify_fixed_point`.
        """
        x = np.linspace(xlim[0], xlim[1], n_grid)
        y = np.linspace(ylim[0], ylim[1], n_grid)
        fixed_points = []
        tol = 0.08

        for xi in x:
            for yi in y:
                try:
                    sol = fsolve(lambda s: self.f(0, s, params), [xi, yi], full_output=True)
                    if sol[2] == 1:
                        fp = sol[0]
                        if (xlim[0] - 0.5 <= fp[0] <= xlim[1] + 0.5 and
                                ylim[0] - 0.5 <= fp[1] <= ylim[1] + 0.5):
                            # Verify it's actually a fixed point
                            residual = np.linalg.norm(self.f(0, fp, params))
                            if residual > 0.1:
                                continue
                            # Check for duplicates
                            is_new = True
                            for existing in fixed_points:
                                if np.linalg.norm(np.array(fp) - np.array(existing[:2])) < tol:
                                    is_new = False
                                    break
                            if is_new:
                                J = self.jacobian(fp, params)
                                ev = np.linalg.eigvals(J)
                                stability = self._classify_fixed_point(ev)
                                fixed_points.append([float(fp[0]), float(fp[1]), stability])
                except Exception:
                    # Best-effort search: a failing start point (e.g. f raising
                    # or overflowing) is simply skipped.
                    pass

        return fixed_points

    def jacobian(self, state, params, eps=1e-6):
        """Numerical Jacobian of ``f`` at ``state`` via forward differences."""
        n = len(state)
        J = np.zeros((n, n))
        f0 = np.array(self.f(0, state, params))
        for i in range(n):
            s_plus = np.array(state, dtype=float)
            s_plus[i] += eps
            f_plus = np.array(self.f(0, s_plus.tolist(), params))
            J[:, i] = (f_plus - f0) / eps
        return J

    def _classify_fixed_point(self, eigenvalues):
        """Classify fixed point based on eigenvalues.

        Labels:

        - ``"stable_*"`` when every real part is below -1e-6.
        - ``"unstable_*"`` when every real part is above 1e-6.
        - Suffix ``focus`` when any imaginary part exceeds 1e-6 in magnitude,
          otherwise ``node``.
        - Every other case is reported as ``"saddle"``, which includes
          non-hyperbolic points with a (near-)zero real part.
        """
        real = np.real(eigenvalues)
        imag = np.imag(eigenvalues)

        if all(r < -1e-6 for r in real):
            return "stable_focus" if any(abs(im) > 1e-6 for im in imag) else "stable_node"
        elif all(r > 1e-6 for r in real):
            return "unstable_focus" if any(abs(im) > 1e-6 for im in imag) else "unstable_node"
        return "saddle"

    def compute_trajectory(self, initial_state, params, t_span, dt=0.01):
        """Compute trajectory using solve_ivp.

        Returns rows ``[t, s0, s1, ...]`` sampled every ``dt`` from
        ``t_span[0]`` up to (at most) ``t_span[1]``, or ``[]`` if the
        integration fails.
        """
        try:
            t_eval = np.arange(t_span[0], t_span[1], dt)
            # Floating-point ``arange`` can overshoot the stop value (e.g.
            # ``np.arange(1, 1.3, 0.1)`` ends at 1.3000000000000003), and
            # ``solve_ivp`` rejects any ``t_eval`` outside ``t_span``.
            t_eval = t_eval[t_eval <= t_span[1]]
            sol = solve_ivp(
                lambda t, y: self.f(t, y, params),
                [t_span[0], t_span[1]],
                initial_state,
                method="RK45",
                t_eval=t_eval,
                max_step=dt * 5,
            )
            trajectory = []
            for i, t in enumerate(sol.t):
                row = [float(t)]
                row.extend(float(sol.y[j, i]) for j in range(self.dim))
                trajectory.append(row)
            return trajectory
        except Exception:
            # Best-effort: callers such as detect_regime treat [] as failure.
            return []

    def detect_regime(self, params, xlim, ylim, t_total=120.0, dt=0.05):
        """Detect dynamical regime: fixed_point, limit_cycle, or other.

        Integrates from four initial conditions derived from the window and
        classifies the last (up to 200) samples of each run:

        - ``"fixed_point"`` when both variables have std below 0.025.
        - ``"limit_cycle"`` when the oscillation amplitude is stable between
          the two halves of the window (within 15%).
        - ``"other"`` otherwise, or on failure.

        The most common label wins.
        """
        ics = [
            [xlim[0] * 0.6, ylim[0] * 0.6],
            [xlim[1] * 0.6, ylim[1] * 0.6],
            [(xlim[0] + xlim[1]) * 0.5, (ylim[0] + ylim[1]) * 0.5],
            [xlim[0] + 0.1 * (xlim[1] - xlim[0]), ylim[0] + 0.1 * (ylim[1] - ylim[0])],
        ]

        regimes = []
        for ic in ics:
            try:
                traj = self.compute_trajectory(ic, params, [0, t_total], dt=dt)
                if not traj:
                    regimes.append("other")
                    continue

                n = len(traj)
                n_check = min(200, n // 4)
                if n_check < 10:
                    regimes.append("other")
                    continue

                last = np.array(traj[-n_check:])
                dx = np.std(last[:, 1])
                dy = np.std(last[:, 2])

                if dx < 0.025 and dy < 0.025:
                    regimes.append("fixed_point")
                else:
                    # Check amplitude stability for limit cycle
                    mid = len(last) // 2
                    x_vals = last[:, 1]
                    y_vals = last[:, 2]
                    amp1_x = np.max(x_vals[:mid]) - np.min(x_vals[:mid])
                    amp2_x = np.max(x_vals[mid:]) - np.min(x_vals[mid:])
                    amp1_y = np.max(y_vals[:mid]) - np.min(y_vals[:mid])
                    amp2_y = np.max(y_vals[mid:]) - np.min(y_vals[mid:])

                    x_stable = abs(amp1_x - amp2_x) < 0.15 * max(amp1_x, 0.01)
                    y_stable = abs(amp1_y - amp2_y) < 0.15 * max(amp1_y, 0.01)

                    if x_stable and y_stable:
                        regimes.append("limit_cycle")
                    else:
                        regimes.append("other")
            except Exception:
                regimes.append("other")

        return Counter(regimes).most_common(1)[0][0]

Functions

f

f(t, state, params)

Compute derivatives. Returns list of length dim.

Source code in src/tvb_phaseplane/models.py
def f(self, t, state, params):
    """Compute derivatives. Returns list of length dim.

    Abstract hook of the base model: concrete models override this to return
    the time derivative of every state variable at time ``t`` for the given
    ``state`` and parameter dict ``params``.

    Raises
    ------
    NotImplementedError
        Always, when not overridden; the message names the concrete class
        that is missing an implementation so the failing model is obvious.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must implement f(t, state, params)"
    )

compute_trajectory

compute_trajectory(initial_state, params, t_span, dt=0.01)

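Integrate the model with `solve_ivp` (RK45), sampling every `dt` over `t_span`. Returns rows `[t, x_1, ..., x_dim]`, or an empty list if integration fails.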

Source code in src/tvb_phaseplane/models.py
def compute_trajectory(self, initial_state, params, t_span, dt=0.01):
    """Compute trajectory using solve_ivp.

    Integrates ``self.f`` from ``initial_state`` with
    :func:`scipy.integrate.solve_ivp` (RK45) and samples the solution on the
    grid ``np.arange(t_span[0], t_span[1], dt)``, so the end point is normally
    not included.

    Parameters
    ----------
    initial_state : sequence of float
        Starting value of each state variable.
    params : dict
        Parameters forwarded unchanged to ``self.f``.
    t_span : sequence of two floats
        ``(t_start, t_end)`` integration interval.
    dt : float, optional
        Sampling interval of the returned trajectory; the solver's maximum
        internal step is ``5 * dt``.

    Returns
    -------
    list of list of float
        One row ``[t, x_1, ..., x_dim]`` per sample. If the solver stops
        early, only the samples it reached are returned. Any exception during
        integration yields ``[]`` (best-effort; callers treat an empty result
        as a failed run).
    """
    try:
        t_eval = np.arange(t_span[0], t_span[1], dt)
        # A float step can make np.arange overshoot the stop value by a
        # rounding error (e.g. arange(1, 1.3, 0.1)[-1] == 1.3000000000000003).
        # solve_ivp raises on any t_eval value outside t_span, which would
        # otherwise discard the whole trajectory, so drop such samples.
        lo, hi = min(t_span[0], t_span[1]), max(t_span[0], t_span[1])
        t_eval = t_eval[(t_eval >= lo) & (t_eval <= hi)]
        sol = solve_ivp(
            lambda t, y: self.f(t, y, params),
            [t_span[0], t_span[1]],
            initial_state,
            method="RK45",
            t_eval=t_eval,
            max_step=dt * 5,
        )
        # Stack time on top of the first `dim` state rows, then transpose to
        # one row per sample; tolist() yields plain Python floats.
        return np.vstack((sol.t, sol.y[: self.dim])).T.tolist()
    except Exception:
        return []

detect_regime

detect_regime(params, xlim, ylim, t_total=120.0, dt=0.05)

Detect dynamical regime: fixed_point, limit_cycle, or other.

Source code in src/tvb_phaseplane/models.py
def detect_regime(self, params, xlim, ylim, t_total=120.0, dt=0.05):
    """Detect dynamical regime: fixed_point, limit_cycle, or other.

    Four probe trajectories are integrated for ``t_total`` time units (sampled
    every ``dt``) from starting points derived from ``xlim`` and ``ylim``.
    Each probe is labelled from the tail of its trajectory:

    * ``"other"`` if the trajectory is empty, its tail would hold fewer than
      10 samples, or anything raises while classifying it;
    * ``"fixed_point"`` if the standard deviation of both state columns over
      the tail is below 0.025;
    * ``"limit_cycle"`` if, for both state columns, the peak-to-peak swing of
      the first and second halves of the tail differ by less than 15 % of the
      first half's swing (floored at 0.01);
    * ``"other"`` otherwise.

    The tail is the last ``min(200, len(trajectory) // 4)`` samples. The most
    frequent label is returned; ties go to the label seen first.

    NOTE(review): the tail covers at most ``200 * dt`` time units, so
    oscillations with a longer period may fail the amplitude test, and a
    constant-slope drift produces equal half-window swings and is labelled
    ``"limit_cycle"``.
    """
    x_a, x_b = xlim[0], xlim[1]
    y_a, y_b = ylim[0], ylim[1]
    probes = [
        [x_a * 0.6, y_a * 0.6],
        [x_b * 0.6, y_b * 0.6],
        [(x_a + x_b) * 0.5, (y_a + y_b) * 0.5],
        [x_a + 0.1 * (x_b - x_a), y_a + 0.1 * (y_b - y_a)],
    ]

    def swing(values):
        # Peak-to-peak range of a 1-D sample.
        return np.max(values) - np.min(values)

    def steady_amplitude(values):
        # Compare the swing of the two halves of the tail.
        split = len(values) // 2
        early = swing(values[:split])
        late = swing(values[split:])
        return abs(early - late) < 0.15 * max(early, 0.01)

    def classify(start):
        try:
            trajectory = self.compute_trajectory(start, params, [0, t_total], dt=dt)
            if not trajectory:
                return "other"
            window = min(200, len(trajectory) // 4)
            if window < 10:
                return "other"
            tail = np.array(trajectory[-window:])
            # Column 0 is time; columns 1 and 2 are the two state variables.
            xs, ys = tail[:, 1], tail[:, 2]
            spread_x, spread_y = np.std(xs), np.std(ys)
            if spread_x < 0.025 and spread_y < 0.025:
                return "fixed_point"
            steady_x = steady_amplitude(xs)
            steady_y = steady_amplitude(ys)
            return "limit_cycle" if steady_x and steady_y else "other"
        except Exception:
            return "other"

    labels = [classify(start) for start in probes]
    return Counter(labels).most_common(1)[0][0]

Wilson-Cowan

WilsonCowan

Bases: BaseModel

Wilson-Cowan model of excitatory and inhibitory neural populations.

Source code in src/tvb_phaseplane/models.py
class WilsonCowan(BaseModel):
    """Wilson-Cowan rate model of an excitatory (E) and an inhibitory (I) population.

    Each population relaxes toward a logistic response of its net input::

        dE/dt = -E + S(aee*E - aei*I + Pe; ke, thetae)
        dI/dt = -I + S(aie*E - aii*I + Pi; ki, thetai)

    where ``S(x; k, theta) = 1 / (1 + exp(-k * (x - theta)))``.
    """

    name = "wilson_cowan"
    dim = 2
    state_names = ["E", "I"]
    # Baseline parameters; values passed to ``f`` via ``params`` override them.
    default_params = {
        "aee": 10.0,
        "aei": 10.0,
        "aie": 10.0,
        "aii": 2.0,
        "Pe": -2.0,
        "Pi": -8.0,
        "ke": 1.0,
        "ki": 1.0,
        "thetae": 4.0,
        "thetai": 4.0,
    }
    # Presumably (min, max, default, label) per parameter for the UI controls;
    # the third entry matches ``default_params``.
    param_info = {
        "aee": (0.0, 20.0, 10.0, "E→E coupling"),
        "aei": (0.0, 20.0, 10.0, "I→E coupling"),
        "aie": (0.0, 20.0, 10.0, "E→I coupling"),
        "aii": (0.0, 20.0, 2.0, "I→I coupling"),
        "Pe": (-10.0, 10.0, -2.0, "External E input"),
        "Pi": (-10.0, 10.0, -8.0, "External I input"),
        "ke": (0.1, 5.0, 1.0, "E sigmoid gain"),
        "ki": (0.1, 5.0, 1.0, "I sigmoid gain"),
        "thetae": (0.0, 10.0, 4.0, "E sigmoid threshold"),
        "thetai": (0.0, 10.0, 4.0, "I sigmoid threshold"),
    }
    # Default axis ranges for E (x) and I (y).
    default_xlim = [-0.2, 1.2]
    default_ylim = [-0.2, 1.2]

    def f(self, t, state, params):
        """Return ``[dE/dt, dI/dt]`` for ``state = (E, I)``; ``t`` is unused."""
        exc, inh = state
        merged = {**self.default_params, **params}

        def response(drive, gain, threshold):
            # Logistic sigmoid. The exponent is clipped to +/-709 because
            # exp() of anything larger overflows float64 (ln(max) ~ 709.78).
            exponent = np.clip(-gain * (drive - threshold), -709, 709)
            return 1.0 / (1.0 + np.exp(exponent))

        drive_e = merged["aee"] * exc - merged["aei"] * inh + merged["Pe"]
        drive_i = merged["aie"] * exc - merged["aii"] * inh + merged["Pi"]
        d_exc = -exc + response(drive_e, merged["ke"], merged["thetae"])
        d_inh = -inh + response(drive_i, merged["ki"], merged["thetai"])
        return [d_exc, d_inh]

FitzHugh-Nagumo

FitzHughNagumo

Bases: BaseModel

FitzHugh-Nagumo model of excitable neuron dynamics.

Source code in src/tvb_phaseplane/models.py
class FitzHughNagumo(BaseModel):
    """FitzHugh-Nagumo two-variable model of an excitable neuron.

    With membrane variable ``v`` and recovery variable ``w``::

        dv/dt = v - v**3 / 3 - w + I
        dw/dt = epsilon * (v + a - b * w)
    """

    name = "fitzhugh_nagumo"
    dim = 2
    state_names = ["v", "w"]
    # Baseline parameters; values passed to ``f`` via ``params`` override them.
    default_params = {"a": 0.7, "b": 0.8, "epsilon": 0.08, "I": 0.5}
    # Presumably (min, max, default, label) per parameter for the UI controls.
    param_info = {
        "a": (-1.0, 2.0, 0.7, "Recovery offset"),
        "b": (0.0, 2.0, 0.8, "Recovery gain"),
        "epsilon": (0.001, 1.0, 0.08, "Time scale (ε)"),
        "I": (-2.0, 2.0, 0.5, "External current"),
    }
    # Default axis ranges for v (x) and w (y).
    default_xlim = [-3.0, 3.0]
    default_ylim = [-1.5, 2.0]

    def f(self, t, state, params):
        """Return ``[dv/dt, dw/dt]`` for ``state = (v, w)``; ``t`` is unused."""
        membrane, recovery = state
        merged = {**self.default_params, **params}
        d_membrane = membrane - membrane ** 3 / 3.0 - recovery + merged["I"]
        d_recovery = merged["epsilon"] * (membrane + merged["a"] - merged["b"] * recovery)
        return [d_membrane, d_recovery]

MPR (Quadratic Integrate-and-Fire)

MPRModel

Bases: BaseModel

Montbrió-Pazó-Roxin exact firing-rate equations for QIF neurons.

The macroscopic variables are firing rate (r) and mean membrane potential (v). See: Phys. Rev. X 5, 021028 (2015).

Source code in src/tvb_phaseplane/models.py
class MPRModel(BaseModel):
    """Montbrió-Pazó-Roxin exact mean-field equations for a population of QIF neurons.

    State is the population firing rate ``r`` and mean membrane potential
    ``v``::

        dr/dt = delta / pi + 2 * r * v
        dv/dt = v**2 + eta_bar + J * r + I - (pi * r)**2

    Reference: Montbrió, Pazó & Roxin, Phys. Rev. X 5, 021028 (2015).
    """

    name = "mpr"
    dim = 2
    state_names = ["r", "v"]
    # Baseline parameters; values passed to ``f`` via ``params`` override them.
    default_params = {
        "delta": 1.0,
        "eta_bar": -5.0,
        "J": 15.0,
        "I": 0.0,
    }
    # Presumably (min, max, default, label) per parameter for the UI controls.
    param_info = {
        "delta": (0.01, 5.0, 1.0, "Lorentzian half-width Δ"),
        "eta_bar": (-20.0, 10.0, -5.0, "Mean excitability η̄"),
        "J": (-20.0, 30.0, 15.0, "Synaptic coupling J"),
        "I": (-10.0, 10.0, 0.0, "External input I"),
    }
    # Default axis ranges for r (x) and v (y).
    default_xlim = [0.0, 2.0]
    default_ylim = [-4.0, 2.0]

    def f(self, t, state, params):
        """Return ``[dr/dt, dv/dt]`` for ``state = (r, v)``; ``t`` is unused."""
        rate, potential = state
        merged = {**self.default_params, **params}
        # Floor the rate at 1e-10 so the right-hand side never sees r <= 0.
        rate = max(rate, 1e-10)
        d_rate = merged["delta"] / np.pi + 2 * rate * potential
        d_potential = (
            potential**2
            + merged["eta_bar"]
            + merged["J"] * rate
            + merged["I"]
            - (np.pi * rate) ** 2
        )
        return [d_rate, d_potential]