Two Moons Classification

This example demonstrates how to use the Neuralk NICLClassifier on the classic two moons dataset: a simple binary classification task with a non-linear decision boundary.

Note

For this example to run, the environment variable API_KEY must be set to your Neuralk API key.

Generate the two moons dataset

We use the two_moons function from the neuralk.datasets module to obtain the data as a CSV file, load it with Polars, and split it into training and test sets.

import os

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from neuralk import NICLClassifier
from neuralk.datasets import two_moons

# Load the two moons dataset
moons_data = two_moons()
df = pl.read_csv(moons_data["path"])

X = df.drop("label").to_numpy().astype(np.float32)
y = df["label"].to_numpy()

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"{X_train.shape=} {y_train.shape=} {X_test.shape=} {y_test.shape=}")
X_train.shape=(400, 2) y_train.shape=(400,) X_test.shape=(100, 2) y_test.shape=(100,)
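To see why this task calls for a non-linear model, you can build a comparable dataset locally with scikit-learn's make_moons and compare a linear baseline against a non-linear one. This is a minimal sketch under stated assumptions: n_samples=500 and noise=0.1 are chosen to reproduce the 400/100 split printed above and are not taken from neuralk.datasets.two_moons, and both scikit-learn models are local stand-ins that never call the Neuralk API.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Assumed parameters: 500 samples so that a 20% test split gives 400/100 rows
X_local, y_local = make_moons(n_samples=500, noise=0.1, random_state=0)
X_local = X_local.astype(np.float32)

Xl_train, Xl_test, yl_train, yl_test = train_test_split(
    X_local, y_local, test_size=0.2, random_state=42
)

# A straight-line boundary cannot follow the interleaved half-circles...
linear_model = LogisticRegression().fit(Xl_train, yl_train)
linear_acc = accuracy_score(yl_test, linear_model.predict(Xl_test))

# ...whereas a local, non-linear method such as k-nearest neighbours can
knn_model = KNeighborsClassifier(n_neighbors=5).fit(Xl_train, yl_train)
knn_acc = accuracy_score(yl_test, knn_model.predict(Xl_test))

print(f"{Xl_train.shape=} {Xl_test.shape=}")
print(f"Linear baseline accuracy: {linear_acc:.3f}")
print(f"k-NN accuracy: {knn_acc:.3f}")
```

The gap between the two scores is what "non-linear decision boundary" means in practice for this dataset.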

Fit and predict with the Neuralk NICLClassifier

The NICLClassifier uses Neuralk’s In-Context Learning model. No long-running training takes place: the model is already pretrained, and the training data passed to fit is used as context when making predictions.

classifier = NICLClassifier(api_key=os.environ["API_KEY"])
classifier.fit(X_train, y_train)

y_pred = classifier.predict(X_test)

acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc:.3f}")

# Check API response details
print(f"Credits consumed: {classifier.credits_consumed}")
print(f"Latency: {classifier.latency_ms}ms")
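The idea of "training data as context" can be illustrated with a toy scikit-learn-style estimator whose fit only stores the labelled rows and whose predict does all the work by looking at that stored context. This is a hypothetical sketch, not Neuralk's model: the class name ToyContextClassifier and its Gaussian-kernel voting rule are assumptions chosen only to show the fit/predict division of labour, whereas NICLClassifier sends the context to a pretrained model through the Neuralk API.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import make_moons
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class ToyContextClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical illustration: fit() stores context, predict() uses it."""

    def __init__(self, bandwidth=0.3):
        self.bandwidth = bandwidth

    def fit(self, X, y):
        # No optimisation happens here: the labelled rows are simply kept
        X, y = check_X_y(X, y)
        self.X_context_ = X
        self.classes_, self.y_context_ = np.unique(y, return_inverse=True)
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # Squared distances between every query row and every context row
        d2 = ((X[:, None, :] - self.X_context_[None, :, :]) ** 2).sum(axis=-1)
        weights = np.exp(-d2 / (2 * self.bandwidth**2))
        # Total kernel weight per class; the heaviest class wins
        scores = np.stack(
            [weights[:, self.y_context_ == k].sum(axis=1)
             for k in range(len(self.classes_))],
            axis=1,
        )
        return self.classes_[scores.argmax(axis=1)]


# Usage on locally generated moons (make_moons shuffles by default)
X_demo, y_demo = make_moons(n_samples=500, noise=0.1, random_state=0)
toy = ToyContextClassifier().fit(X_demo[:400], y_demo[:400])
toy_pred = toy.predict(X_demo[400:])
toy_acc = (toy_pred == y_demo[400:]).mean()
print(f"Toy in-context accuracy: {toy_acc:.3f}")
```

Because fit does no optimisation, its cost is negligible and all computation is deferred to predict, which mirrors how the example above only reaches the Neuralk service when predict is called.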

Visualize the results

We plot the ground truth labels alongside the model predictions.

plt.rcParams.update(
    {
        "axes.edgecolor": "#4d4d4d",
        "axes.linewidth": 1.2,
        "axes.facecolor": "#f5f5f5",
        "figure.facecolor": "white",
    }
)

fig, axes = plt.subplots(1, 2, figsize=(11, 5), dpi=120)
titles = ["Ground Truth", f"Model Prediction\nAccuracy: {acc:.2f}"]
colors = ["#1a73e8", "#ffa600"]  # Blue & orange

for idx, ax in enumerate(axes):
    labels = y_test if idx == 0 else y_pred
    for lab in np.unique(labels):
        ax.scatter(
            X_test[labels == lab, 0],
            X_test[labels == lab, 1],
            s=70,
            marker="o",
            c=colors[int(lab)],
            edgecolors="white",
            linewidths=0.8,
            alpha=0.9,
            label=f"Class {lab}" if idx == 0 else None,
            zorder=3,
        )

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")
    ax.set_title(titles[idx], fontsize=14, weight="bold", pad=12)
    ax.grid(False)

    x_margin = 0.4
    y_margin = 0.4
    ax.set_xlim(X_test[:, 0].min() - x_margin, X_test[:, 0].max() + x_margin)
    ax.set_ylim(X_test[:, 1].min() - y_margin, X_test[:, 1].max() + y_margin)

    ax.text(
        0.05,
        0.98,
        chr(ord("A") + idx),
        transform=ax.transAxes,
        fontsize=16,
        fontweight="bold",
        va="top",
        ha="right",
    )

handles, labels_ = axes[0].get_legend_handles_labels()
fig.legend(
    handles,
    labels_,
    loc="lower center",
    ncol=2,
    frameon=False,
    fontsize=12,
    bbox_to_anchor=(0.5, 0.02),
)

fig.tight_layout()
plt.subplots_adjust(bottom=0.05)
plt.show()
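Since the introduction describes the boundary as non-linear, it can also help to draw it directly by predicting on a dense grid and shading the resulting regions. Calling NICLClassifier.predict on tens of thousands of grid points would consume API credits, so this sketch uses a local scikit-learn KNeighborsClassifier trained on make_moons data as a stand-in; the grid resolution, the 0.4 margin and the colours are illustrative assumptions borrowed from the plot above.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_moons
from sklearn.neighbors import KNeighborsClassifier

X_src, y_src = make_moons(n_samples=500, noise=0.1, random_state=0)
model = KNeighborsClassifier(n_neighbors=5).fit(X_src, y_src)

# Build a 200 x 200 grid covering the data plus a 0.4 margin
margin = 0.4
xx, yy = np.meshgrid(
    np.linspace(X_src[:, 0].min() - margin, X_src[:, 0].max() + margin, 200),
    np.linspace(X_src[:, 1].min() - margin, X_src[:, 1].max() + margin, 200),
)
grid_points = np.c_[xx.ravel(), yy.ravel()]
zz = model.predict(grid_points).reshape(xx.shape)

palette = np.array(["#1a73e8", "#ffa600"])  # blue for class 0, orange for class 1
fig_b, ax_b = plt.subplots(figsize=(6, 5))
# Shade the predicted region of each class, then overlay the true labels
ax_b.contourf(xx, yy, zz, levels=[-0.5, 0.5, 1.5], colors=list(palette), alpha=0.25)
ax_b.scatter(
    X_src[:, 0], X_src[:, 1], c=palette[y_src], s=15,
    edgecolors="white", linewidths=0.5,
)
ax_b.set_aspect("equal")
ax_b.set_title("Decision regions (local k-NN stand-in)")
plt.show()
```

Swapping the stand-in for a fitted NICLClassifier would only change the model line, at the cost of one remote prediction over all grid points.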

Total running time of the script: (0 minutes 0.467 seconds)
