5 Commits

Author SHA1 Message Date
Palash Tyagi
500f5c371d Merge 0a9d9e38c3 into 4cb095b71f 2025-05-03 19:04:41 +01:00
0a9d9e38c3 Merge branch 'main' into floatops 2025-05-03 01:32:25 +01:00
4d846287e8 Merge branch 'main' into floatops 2025-05-02 23:38:46 +01:00
bdd3293c65 Merge branch 'main' into floatops 2025-05-01 01:08:16 +01:00
Palash Tyagi
bda1298397 empty commit 2025-05-01 01:05:37 +01:00
50 changed files with 3806 additions and 6632 deletions

View File

@@ -1,73 +0,0 @@
# actions/runner-fallback/action.yml
# Composite action: probes the repo's self-hosted runner fleet via the REST API
# and outputs a runs-on label list — the primary labels if a matching runner is
# online, otherwise the fallback labels.
name: "Runner Fallback"
description: |
Chooses a self-hosted runner when one with the required labels is online,
otherwise returns a fallback GitHub-hosted label.
inputs:
primary-runner:
description: 'Comma-separated label list for the preferred self-hosted runner (e.g. "self-hosted,linux")'
required: true
fallback-runner:
description: 'Comma-separated label list or single label for the fallback (e.g. "ubuntu-latest")'
required: true
github-token:
description: 'GitHub token with repo admin read permissions'
required: true
outputs:
use-runner:
description: "JSON array of labels you can feed straight into runs-on"
value: ${{ steps.pick.outputs.use-runner }}
runs:
using: "composite"
steps:
- name: Check self-hosted fleet
id: pick
shell: bash
env:
TOKEN: ${{ inputs.github-token }}
PRIMARY: ${{ inputs.primary-runner }}
FALLBACK: ${{ inputs.fallback-runner }}
run: |
# -------- helper -----------
# Turns "a,b,c" into a JSON-ish array string.
# NOTE(review): the printf loop leaves a trailing comma before ']'
# (e.g. ["self-hosted",]), which is not strict JSON — verify that
# fromJSON() in the calling workflow actually accepts this.
to_json_array () {
local list="$1"; IFS=',' read -ra L <<<"$list"
printf '['; printf '"%s",' "${L[@]}"; printf ']'
}
# -------- query API ---------
# Lists up to 100 repo runners; pagination beyond that is not handled.
repo="${{ github.repository }}"
runners=$(curl -s -H "Authorization: Bearer $TOKEN" \
-H "Accept: application/vnd.github+json" \
"https://api.github.com/repos/$repo/actions/runners?per_page=100")
# Debug: Print runners content
# echo "Runners response: $runners"
# Check if runners is null or empty
if [ -z "$runners" ] || [ "$runners" = "null" ]; then
echo "❌ Error: Unable to fetch runners or no runners found." >&2
exit 1
fi
# Process runners only if valid
# A runner matches only if it carries EVERY requested primary label.
IFS=',' read -ra WANT <<<"$PRIMARY"
online_found=0
while read -r row; do
labels=$(jq -r '.labels[].name' <<<"$row")
ok=1
for w in "${WANT[@]}"; do
grep -Fxq "$w" <<<"$labels" || { ok=0; break; }
done
[ "$ok" -eq 1 ] && { online_found=1; break; }
done < <(jq -c '.runners[] | select(.status=="online")' <<<"$runners")
if [ "$online_found" -eq 1 ]; then
echo "✅ Self-hosted runner online."
echo "use-runner=$(to_json_array "$PRIMARY")" >>"$GITHUB_OUTPUT"
else
echo "❌ No matching self-hosted runner online - using fallback."
echo "use-runner=$(to_json_array "$FALLBACK")" >>"$GITHUB_OUTPUT"
fi

View File

@@ -1,74 +0,0 @@
<!-- Static landing page for the rustframe project: dark-themed card with the
     logo and links to docs, benchmarks, crates.io, docs.rs and the mirrors. -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Rustframe</title>
<link rel="icon" type="image/png" href="./rustframe_logo.png">
<style>
body {
font-family: Arial, sans-serif;
background-color: #2b2b2b;
color: #d4d4d4;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
}
main {
text-align: center;
padding: 20px;
background-color: #3c3c3c;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
max-width: 600px;
}
img {
max-width: 100px;
margin-bottom: 20px;
}
h1 {
/* logo is b35f20 */
color: #f8813f;
}
a {
color: #ff9a60;
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
</style>
</head>
<body>
<main>
<h1>
<img src="./rustframe_logo.png" alt="Rustframe Logo"><br>
Rustframe
</h1>
<h2>A lightweight dataframe & math toolkit for Rust</h2>
<hr style="border: 1px solid #d4d4d4; margin: 20px 0;">
<!-- Link hub: relative paths below assume this page is deployed at the
     GitHub Pages site root alongside docs/ and benchmark-report/. -->
<p>
📚 <a href="https://magnus167.github.io/rustframe/docs">Docs</a> |
📊 <a href="https://magnus167.github.io/rustframe/benchmark-report/">Benchmarks</a>
<br><br>
🦀 <a href="https://crates.io/crates/rustframe">Crates.io</a> |
🔖 <a href="https://docs.rs/rustframe/latest/rustframe/">docs.rs</a>
<br><br>
🐙 <a href="https://github.com/Magnus167/rustframe">GitHub</a> |
🌐 <a href="https://gitea.nulltech.uk/Magnus167/rustframe">Gitea mirror</a>
</p>
</main>
</body>
</html>

View File

@@ -1,30 +0,0 @@
# Image for a self-hosted GitHub Actions runner (arm64 build of the runner binary).
# NOTE(review): the download URL is hard-coded to linux-arm64 while the base image
# is ubuntu:latest for whatever arch builds it — confirm this image is only built
# on ARM hosts, or parameterize the arch.
FROM ubuntu:latest
ARG RUNNER_VERSION="2.323.0"
# Prevents installdependencies.sh from prompting the user and blocking the image creation
ARG DEBIAN_FRONTEND=noninteractive
RUN apt update -y && apt upgrade -y && useradd -m docker
RUN apt install -y --no-install-recommends \
curl jq build-essential libssl-dev libffi-dev python3 python3-venv python3-dev python3-pip \
# dot net core dependencies
# NOTE(review): libicu74 is release-specific; ubuntu:latest moving to a newer
# release will break this pin — consider pinning the base image instead.
libicu74 libssl3 libkrb5-3 zlib1g libcurl4
# Download and unpack the runner into /home/docker/actions-runner.
RUN cd /home/docker && mkdir actions-runner && cd actions-runner \
&& curl -O -L https://github.com/actions/runner/releases/download/v${RUNNER_VERSION}/actions-runner-linux-arm64-${RUNNER_VERSION}.tar.gz \
&& tar xzf ./actions-runner-linux-arm64-${RUNNER_VERSION}.tar.gz
RUN chown -R docker ~docker && /home/docker/actions-runner/bin/installdependencies.sh
COPY entrypoint.sh entrypoint.sh
# make the script executable
RUN chmod +x entrypoint.sh
# since the config and run script for actions are not allowed to be run by root,
# set the user to "docker" so all subsequent commands are run as the docker user
USER docker
ENTRYPOINT ["./entrypoint.sh"]

View File

@@ -1,18 +0,0 @@
# docker-compose.yml
# Runs the self-hosted runner container; REPO/GH_TOKEN/RUNNER_LABELS come from .env.
services:
github-runner:
build:
context: .
args:
RUNNER_VERSION: 2.323.0
# container_name commented to allow for multiple runners
# container_name: github-runner
env_file:
- .env
volumes:
# NOTE(review): the Dockerfile/entrypoint install the runner under
# /home/docker/actions-runner, but this mounts /home/runner/... — the
# _work directory is likely not actually persisted; verify the path.
- runner-work:/home/runner/actions-runner/_work
restart: unless-stopped
volumes:
runner-work:

View File

@@ -1,24 +0,0 @@
#!/bin/bash
# Registers this container as a self-hosted GitHub Actions runner for $REPO,
# runs it, and deregisters on SIGINT/SIGTERM so stale runners don't accumulate.
# Required env (from .env): REPO, GH_TOKEN, RUNNER_LABELS.
REPOSITORY=$REPO
ACCESS_TOKEN=$GH_TOKEN
LABELS=$RUNNER_LABELS
# echo "REPO ${REPOSITORY}"
# echo "ACCESS_TOKEN ${ACCESS_TOKEN}"
# Exchange the PAT for a short-lived runner registration token.
REG_TOKEN=$(curl -X POST -H "Authorization: token ${ACCESS_TOKEN}" -H "Accept: application/vnd.github+json" https://api.github.com/repos/${REPOSITORY}/actions/runners/registration-token | jq .token --raw-output)
cd /home/docker/actions-runner
./config.sh --url https://github.com/${REPOSITORY} --token ${REG_TOKEN} --labels ${LABELS}
# Deregister the runner on shutdown.
# NOTE(review): reuses the registration token captured at startup; registration
# tokens are short-lived, so removal may fail for long-running containers —
# consider fetching a fresh remove-token inside cleanup. Confirm against the API docs.
cleanup() {
echo "Removing runner..."
./config.sh remove --unattended --token ${REG_TOKEN}
}
# 130/143 are the conventional exit codes for SIGINT/SIGTERM.
trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM
# Run in background and wait, so the traps can fire while run.sh is active.
./run.sh & wait $!

View File

@@ -1,9 +0,0 @@
# Repository name
REPO="Magnus167/rustframe"
# GitHub runner token
GH_TOKEN="some_token_here"
# Labels for the runner
RUNNER_LABELS=self-hosted-linux,linux

View File

@@ -1,4 +0,0 @@
# Build the runner image (if needed) and start it detached.
docker compose up -d --build
# Variant: start two runner replicas of the same service.
# docker compose up -d --build --scale github-runner=2

View File

@@ -1,45 +0,0 @@
# Image for a self-hosted GitHub Actions runner (x64) with Rust, Python and gh CLI
# preinstalled for rustframe CI jobs.
FROM ubuntu:latest
ARG RUNNER_VERSION="2.323.0"
# Prevents installdependencies.sh from prompting the user and blocking the image creation
ARG DEBIAN_FRONTEND=noninteractive
RUN apt update -y && apt upgrade -y && useradd -m docker
RUN apt install -y --no-install-recommends \
curl jq git unzip \
# dev dependencies
build-essential libssl-dev libffi-dev python3 python3-venv python3-dev python3-pip \
# dot net core dependencies
# NOTE(review): libicu74 is release-specific; a base-image bump can break this pin.
libicu74 libssl3 libkrb5-3 zlib1g libcurl4 \
# Rust and Cargo dependencies
gcc cmake
# Install GitHub CLI
RUN curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg \
&& chmod go+r /usr/share/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
&& apt update -y && apt install -y gh \
&& rm -rf /var/lib/apt/lists/*
# Install Rust and Cargo
# NOTE(review): this RUN executes as root (HOME=/root), so rustup installs to
# /root/.cargo — yet PATH/HOME below point at /home/docker/.cargo. Verify cargo
# is actually reachable for the docker user at runtime.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/home/docker/.cargo/bin:${PATH}"
ENV HOME="/home/docker"
RUN cd /home/docker && mkdir actions-runner && cd actions-runner \
&& curl -O -L https://github.com/actions/runner/releases/download/v${RUNNER_VERSION}/actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz \
&& tar xzf ./actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz
RUN chown -R docker ~docker && /home/docker/actions-runner/bin/installdependencies.sh
COPY entrypoint.sh entrypoint.sh
# make the script executable
RUN chmod +x entrypoint.sh
# since the config and run script for actions are not allowed to be run by root,
# set the user to "docker" so all subsequent commands are run as the docker user
USER docker
ENTRYPOINT ["./entrypoint.sh"]

View File

@@ -1,18 +0,0 @@
# docker-compose.yml
# Runs the x64 self-hosted runner container; REPO/GH_TOKEN/RUNNER_LABELS come from .env.
services:
github-runner:
build:
context: .
args:
RUNNER_VERSION: 2.323.0
# container_name commented to allow for multiple runners
# container_name: github-runner
env_file:
- .env
volumes:
# NOTE(review): the image installs the runner under /home/docker/actions-runner,
# but this volume targets /home/runner/... — verify the mount path matches.
- runner-work:/home/runner/actions-runner/_work
restart: unless-stopped
volumes:
runner-work:

View File

@@ -1,24 +0,0 @@
#!/bin/bash
# Registers this container as a self-hosted GitHub Actions runner for $REPO,
# runs it, and deregisters on SIGINT/SIGTERM so stale runners don't accumulate.
# Required env (from .env): REPO, GH_TOKEN, RUNNER_LABELS.
REPOSITORY=$REPO
ACCESS_TOKEN=$GH_TOKEN
LABELS=$RUNNER_LABELS
# echo "REPO ${REPOSITORY}"
# echo "ACCESS_TOKEN ${ACCESS_TOKEN}"
# Exchange the PAT for a short-lived runner registration token.
REG_TOKEN=$(curl -X POST -H "Authorization: token ${ACCESS_TOKEN}" -H "Accept: application/vnd.github+json" https://api.github.com/repos/${REPOSITORY}/actions/runners/registration-token | jq .token --raw-output)
cd /home/docker/actions-runner
./config.sh --url https://github.com/${REPOSITORY} --token ${REG_TOKEN} --labels ${LABELS}
# Deregister the runner on shutdown.
# NOTE(review): reuses the startup registration token; these tokens are
# short-lived, so removal may fail for long-running containers — confirm.
cleanup() {
echo "Removing runner..."
./config.sh remove --unattended --token ${REG_TOKEN}
}
# 130/143 are the conventional exit codes for SIGINT/SIGTERM.
trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM
# Run in background and wait, so the traps can fire while run.sh is active.
./run.sh & wait $!

View File

@@ -1,9 +0,0 @@
# Repository name
REPO="Magnus167/rustframe"
# GitHub runner token
GH_TOKEN="some_token_here"
# Labels for the runner
RUNNER_LABELS=self-hosted-linux,linux

View File

@@ -1,4 +0,0 @@
# Build the runner image (if needed) and start it detached.
docker compose up -d --build
# Variant: start two runner replicas of the same service.
# docker compose up -d --build --scale github-runner=2

View File

@@ -1,426 +0,0 @@
# create_benchmark_table.py
import argparse
import json
import re
import sys
from pathlib import Path
from pprint import pprint
from collections import defaultdict
from typing import Dict, Any, Optional
import pandas as pd
import html # Import the html module for escaping
# Regular expression to parse "test_name (size)" format
DIR_PATTERN = re.compile(r"^(.*?) \((.*?)\)$")
# Standard location for criterion estimates relative to the benchmark dir
ESTIMATES_PATH_NEW = Path("new") / "estimates.json"
# Fallback location (older versions or baseline comparisons)
ESTIMATES_PATH_BASE = Path("base") / "estimates.json"
# Standard location for the HTML report relative to the benchmark's specific directory
REPORT_HTML_RELATIVE_PATH = Path("report") / "index.html"
def get_default_criterion_report_path() -> Path:
    """Return the default location of Criterion's top-level HTML report.

    Returns:
        Path: ``target/criterion/report/index.html`` — the index page that
        Criterion.rs writes for the combined benchmark report.

    Note:
        The previous docstring claimed this returned ``target/criterion``;
        the function has always returned the report's ``index.html`` path.
    """
    return Path("target") / "criterion" / "report" / "index.html"
def load_criterion_reports(
criterion_root_dir: Path,
) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""
Loads Criterion benchmark results from a specified directory and finds HTML paths.
Args:
criterion_root_dir: The Path object pointing to the main
'target/criterion' directory.
Returns:
A nested dictionary structured as:
{ test_name: { size: {'json': json_content, 'html_path': relative_html_path}, ... }, ... }
Returns an empty dict if the root directory is not found or empty.
"""
results: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)
# Bail out early (empty dict, message on stderr) when the root is absent.
if not criterion_root_dir.is_dir():
print(
f"Error: Criterion root directory not found or is not a directory: {criterion_root_dir}",
file=sys.stderr,
)
return {}
print(f"Scanning for benchmark reports in: {criterion_root_dir}")
# Only directories named "test_name (size)" (see DIR_PATTERN) are benchmarks;
# everything else (e.g. Criterion's own "report" dir) is skipped.
for item in criterion_root_dir.iterdir():
if not item.is_dir():
continue
match = DIR_PATTERN.match(item.name)
if not match:
continue
test_name = match.group(1).strip()
size = match.group(2).strip()
benchmark_dir_name = item.name
benchmark_dir_path = item
# Prefer the fresh "new/estimates.json"; fall back to "base/estimates.json"
# (older Criterion layouts / baseline comparisons).
json_path: Optional[Path] = None
if (benchmark_dir_path / ESTIMATES_PATH_NEW).is_file():
json_path = benchmark_dir_path / ESTIMATES_PATH_NEW
elif (benchmark_dir_path / ESTIMATES_PATH_BASE).is_file():
json_path = benchmark_dir_path / ESTIMATES_PATH_BASE
html_path = benchmark_dir_path / REPORT_HTML_RELATIVE_PATH
# A benchmark is only included when BOTH its estimates JSON and its
# per-benchmark HTML report exist; otherwise warn and skip.
if json_path is None or not json_path.is_file():
print(
f"Warning: Could not find estimates JSON in {benchmark_dir_path}. Skipping benchmark size '{test_name} ({size})'.",
file=sys.stderr,
)
continue
if not html_path.is_file():
print(
f"Warning: Could not find HTML report at expected location {html_path}. Skipping benchmark size '{test_name} ({size})'.",
file=sys.stderr,
)
continue
try:
with json_path.open("r", encoding="utf-8") as f:
json_data = json.load(f)
# Store the parsed estimates plus the report path relative to the
# criterion root, normalized to forward slashes for use in URLs.
results[test_name][size] = {
"json": json_data,
"html_path_relative_to_criterion_root": str(
Path(benchmark_dir_name) / REPORT_HTML_RELATIVE_PATH
).replace("\\", "/"),
}
# Errors on one benchmark are reported but do not abort the scan.
except json.JSONDecodeError:
print(f"Error: Failed to decode JSON from {json_path}", file=sys.stderr)
except IOError as e:
print(f"Error: Failed to read file {json_path}: {e}", file=sys.stderr)
except Exception as e:
print(
f"Error: An unexpected error occurred loading {json_path}: {e}",
file=sys.stderr,
)
# Convert the defaultdict back to a plain dict for the caller.
return dict(results)
def format_nanoseconds(ns: float) -> str:
    """Render a nanosecond duration as a human-readable string.

    Missing values (``pd.NA``/NaN) become ``"-"``; otherwise the value is
    scaled to ns, µs, ms or s and formatted with two decimal places.
    """
    if pd.isna(ns):
        return "-"
    # (upper bound in ns, divisor, unit suffix) — checked smallest first.
    scales = (
        (1_000, 1, "ns"),
        (1_000_000, 1_000, "µs"),
        (1_000_000_000, 1_000_000, "ms"),
    )
    for upper, divisor, unit in scales:
        if ns < upper:
            return f"{ns / divisor:.2f} {unit}"
    # Anything at or above one second.
    return f"{ns / 1_000_000_000:.2f} s"
def generate_html_table_with_links(
results: Dict[str, Dict[str, Dict[str, Any]]], html_base_path: str
) -> str:
"""
Generates a full HTML page with a styled table from benchmark results.
"""
# Inline CSS for the standalone summary page (no external assets needed).
css_styles = """
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol";
line-height: 1.6;
margin: 0;
padding: 20px;
background-color: #f4f7f6;
color: #333;
}
.container {
max-width: 1200px;
margin: 20px auto;
padding: 20px;
background-color: #fff;
box-shadow: 0 0 15px rgba(0,0,0,0.1);
border-radius: 8px;
}
h1 {
color: #2c3e50;
text-align: center;
margin-bottom: 10px;
}
p.subtitle {
text-align: center;
margin-bottom: 8px;
color: #555;
font-size: 0.95em;
}
p.note {
text-align: center;
margin-bottom: 25px;
color: #777;
font-size: 0.85em;
}
.benchmark-table {
width: 100%;
border-collapse: collapse;
margin-top: 25px;
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.benchmark-table th, .benchmark-table td {
border: 1px solid #dfe6e9; /* Lighter border */
padding: 12px 15px;
}
.benchmark-table th {
background-color: #3498db; /* Primary blue */
color: #ffffff;
font-weight: 600; /* Slightly bolder */
text-transform: uppercase;
letter-spacing: 0.05em;
text-align: center; /* Center align headers */
}
.benchmark-table td {
text-align: right; /* Default for data cells (times) */
}
.benchmark-table td:first-child { /* Benchmark Name column */
font-weight: 500;
color: #2d3436;
text-align: left; /* Left align benchmark names */
}
.benchmark-table tbody tr:nth-child(even) {
background-color: #f8f9fa; /* Very light grey for even rows */
}
.benchmark-table tbody tr:hover {
background-color: #e9ecef; /* Slightly darker on hover */
}
.benchmark-table a {
color: #2980b9; /* Link blue */
text-decoration: none;
font-weight: 500;
}
.benchmark-table a:hover {
text-decoration: underline;
color: #1c5a81; /* Darker blue on hover */
}
.no-results {
text-align: center;
font-size: 1.2em;
color: #7f8c8d;
margin-top: 30px;
}
</style>
"""
# Shared page skeleton; the table (or a "no results" notice) goes between these.
html_doc_start = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Criterion Benchmark Results</title>
{css_styles}
</head>
<body>
<div class="container">
<h1 id="criterion-benchmark-results">Criterion Benchmark Results</h1>
"""
html_doc_end = """
</div>
</body>
</html>"""
# Empty input still yields a complete, valid page.
if not results:
return f"""{html_doc_start}
<p class="no-results">No benchmark results found or loaded.</p>
{html_doc_end}"""
# Columns = union of all size labels across benchmarks, sorted numerically
# by the leading dimension.
# NOTE(review): the key assumes every size label starts with an integer
# before an 'x' (e.g. "64x64"); any other label raises ValueError here —
# confirm against the benchmark naming convention.
all_sizes = sorted(
list(set(size for test_data in results.values() for size in test_data.keys())),
key=(lambda x: int(x.split("x")[0])),
)
all_test_names = sorted(list(results.keys()))
# Build the table header row: one column per size.
table_content = """
<p class="subtitle">Each cell links to the detailed Criterion.rs report for that specific benchmark size.</p>
<p class="note">Note: Values shown are the midpoint of the mean confidence interval, formatted for readability.</p>
<p class="note"><a href="report/index.html">[Switch to the standard Criterion.rs report]</a></p>
<table class="benchmark-table">
<thead>
<tr>
<th>Benchmark Name</th>
"""
for size in all_sizes:
table_content += f"<th>{html.escape(size)}</th>\n"
table_content += """
</tr>
</thead>
<tbody>
"""
# One row per benchmark; one cell per size (missing combinations render "-").
for test_name in all_test_names:
table_content += f"<tr>\n"
table_content += f" <td>{html.escape(test_name)}</td>\n"
for size in all_sizes:
cell_data = results.get(test_name, {}).get(size)
mean_value = pd.NA
full_report_url = "#"
if (
cell_data
and "json" in cell_data
and "html_path_relative_to_criterion_root" in cell_data
):
# Displayed value is the midpoint of the mean's confidence interval;
# malformed estimates are warned about and leave the cell as "-".
try:
mean_data = cell_data["json"].get("mean")
if mean_data and "confidence_interval" in mean_data:
ci = mean_data["confidence_interval"]
if "lower_bound" in ci and "upper_bound" in ci:
lower, upper = ci["lower_bound"], ci["upper_bound"]
if isinstance(lower, (int, float)) and isinstance(
upper, (int, float)
):
mean_value = (lower + upper) / 2.0
else:
print(
f"Warning: Non-numeric bounds for {test_name} ({size}).",
file=sys.stderr,
)
else:
print(
f"Warning: Missing confidence_interval bounds for {test_name} ({size}).",
file=sys.stderr,
)
else:
print(
f"Warning: Missing 'mean' data for {test_name} ({size}).",
file=sys.stderr,
)
# Link target = base path + per-benchmark report path, with
# forward slashes so it works as a URL on any platform.
relative_report_path = cell_data[
"html_path_relative_to_criterion_root"
]
joined_path = Path(html_base_path) / relative_report_path
full_report_url = str(joined_path).replace("\\", "/")
except Exception as e:
print(
f"Error processing cell data for {test_name} ({size}): {e}",
file=sys.stderr,
)
formatted_mean = format_nanoseconds(mean_value)
# Wrap the value in a link only when a real report URL was resolved.
if full_report_url and full_report_url != "#":
table_content += f' <td><a href="{html.escape(full_report_url)}">{html.escape(formatted_mean)}</a></td>\n'
else:
table_content += f" <td>{html.escape(formatted_mean)}</td>\n"
table_content += "</tr>\n"
table_content += """
</tbody>
</table>
"""
return f"{html_doc_start}{table_content}{html_doc_end}"
# CLI entry point: scan a Criterion output tree and write a linked HTML summary.
if __name__ == "__main__":
DEFAULT_CRITERION_PATH = "target/criterion"
DEFAULT_OUTPUT_FILE = "./target/criterion/index.html"
DEFAULT_HTML_BASE_PATH = ""
parser = argparse.ArgumentParser(
description="Load Criterion benchmark results from JSON files and generate an HTML table with links to reports."
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Perform a dry run without writing the HTML file.",
)
parser.add_argument(
"--criterion-dir",
type=str,
default=DEFAULT_CRITERION_PATH,
help=f"Path to the main 'target/criterion' directory (default: {DEFAULT_CRITERION_PATH}) containing benchmark data.",
)
parser.add_argument(
"--html-base-path",
type=str,
default=DEFAULT_HTML_BASE_PATH,
help=(
f"Prefix for HTML links to individual benchmark reports. "
f"This is prepended to each report's relative path (e.g., 'benchmark_name/report/index.html'). "
f"If the main output HTML (default: '{DEFAULT_OUTPUT_FILE}') is in the 'target/criterion/' directory, "
f"this should typically be empty (default: '{DEFAULT_HTML_BASE_PATH}'). "
),
)
parser.add_argument(
"--output-file",
type=str,
default=DEFAULT_OUTPUT_FILE,
help=f"Path to save the generated HTML summary report (default: {DEFAULT_OUTPUT_FILE}).",
)
args = parser.parse_args()
# --dry-run exits immediately after argument parsing: nothing is read or
# written. (In CI this doubles as an import/arg-parsing smoke test.)
if args.dry_run:
print(
"Dry run mode: No files will be written. Use --dry-run to skip writing the HTML file."
)
sys.exit(0)
criterion_path = Path(args.criterion_dir)
output_file_path = Path(args.output_file)
# Make sure the output directory exists before doing any scanning work.
try:
output_file_path.parent.mkdir(parents=True, exist_ok=True)
except OSError as e:
print(
f"Error: Could not create output directory {output_file_path.parent}: {e}",
file=sys.stderr,
)
sys.exit(1)
all_results = load_criterion_reports(criterion_path)
# Generate HTML output regardless of whether results were found (handles "no results" page)
html_output = generate_html_table_with_links(all_results, args.html_base_path)
if not all_results:
print("\nNo benchmark results found or loaded.")
# Fallthrough to write the "no results" page generated by generate_html_table_with_links
else:
print("\nSuccessfully loaded benchmark results.")
# pprint(all_results) # Uncomment for debugging
print(
f"Generating HTML report with links using HTML base path: '{args.html_base_path}'"
)
try:
with output_file_path.open("w", encoding="utf-8") as f:
f.write(html_output)
print(f"\nSuccessfully wrote HTML report to {output_file_path}")
# Exit status: 0 on success with results; 1 when no results were found,
# even though the "no results" page was still written successfully.
if not all_results:
sys.exit(1) # Exit with error code if no results, though file is created
sys.exit(0)
except IOError as e:
print(f"Error writing HTML output to {output_file_path}: {e}", file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f"An unexpected error occurred while writing HTML: {e}", file=sys.stderr)
sys.exit(1)

View File

@@ -7,13 +7,9 @@ concurrency:
on: on:
push: push:
branches: [main] branches: [main]
# pull_request: # pull_request:
# branches: [main] # branches: [main]
workflow_dispatch: workflow_dispatch:
workflow_run:
workflows: ["run-benchmarks"]
types:
- completed
permissions: permissions:
contents: read contents: read
@@ -21,23 +17,8 @@ permissions:
pages: write pages: write
jobs: jobs:
pick-runner:
runs-on: ubuntu-latest
outputs:
runner: ${{ steps.choose.outputs.use-runner }}
steps:
- uses: actions/checkout@v4
- id: choose
uses: ./.github/actions/runner-fallback
with:
primary-runner: "self-hosted"
fallback-runner: "ubuntu-latest"
github-token: ${{ secrets.CUSTOM_GH_TOKEN }}
docs-and-testcov: docs-and-testcov:
needs: pick-runner runs-on: ubuntu-latest
runs-on: ${{ fromJson(needs.pick-runner.outputs.runner) }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -48,14 +29,6 @@ jobs:
toolchain: stable toolchain: stable
override: true override: true
- name: Replace logo URL in README.md
env:
LOGO_URL: ${{ secrets.LOGO_URL }}
run: |
# replace with EXAMPLE.COM/LOGO
sed -i 's|.github/rustframe_logo.png|rustframe_logo.png|g' README.md
- name: Build documentation - name: Build documentation
run: cargo doc --no-deps --release run: cargo doc --no-deps --release
@@ -81,10 +54,12 @@ jobs:
fi fi
- name: Export tarpaulin coverage badge JSON - name: Export tarpaulin coverage badge JSON
# extract raw coverage and round to 2 decimal places
run: | run: |
# extract raw coverage
coverage=$(jq '.coverage' tarpaulin-report.json) coverage=$(jq '.coverage' tarpaulin-report.json)
# round to 2 decimal places
formatted=$(printf "%.2f" "$coverage") formatted=$(printf "%.2f" "$coverage")
# build the badge JSON using the pre-formatted string
jq --arg message "$formatted" \ jq --arg message "$formatted" \
'{schemaVersion:1, '{schemaVersion:1,
label:"tarpaulin-report", label:"tarpaulin-report",
@@ -104,77 +79,23 @@ jobs:
<(echo '{}') \ <(echo '{}') \
> last-commit-date.json > last-commit-date.json
- name: Download last available benchmark report
env:
GH_TOKEN: ${{ secrets.CUSTOM_GH_TOKEN }}
run: |
artifact_url=$(
curl -sSL \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${GH_TOKEN}" \
"https://api.github.com/repos/${{ github.repository }}/actions/artifacts" \
| jq -r '
.artifacts[]
| select(.name | startswith("benchmark-reports"))
| .archive_download_url
' \
| head -n 1
)
if [ -z "$artifact_url" ]; then
echo "No benchmark artifact found!"
mkdir -p benchmark-report
echo '<!DOCTYPE html><html><head><title>No Benchmarks</title></head><body><h1>No benchmarks available</h1></body></html>' > benchmark-report/index.html
exit 0
fi
curl -L -H "Authorization: Bearer ${GH_TOKEN}" \
"$artifact_url" -o benchmark-report.zip
# Print all files in the current directory
echo "Files in the current directory:"
ls -al
# check if the zip file is valid
if ! unzip -tq benchmark-report.zip; then
echo "benchmark-report.zip is invalid or corrupted!"
exit 1
fi
unzip -q benchmark-report.zip -d benchmark-report
# echo "<meta http-equiv=\"refresh\" content=\"0; url=report/index.html\">" > benchmark-report/index.html
- name: Copy files to output directory - name: Copy files to output directory
run: | run: |
# mkdir docs
mkdir -p target/doc/docs
mv target/doc/rustframe/* target/doc/docs/
mkdir output mkdir output
cp tarpaulin-report.html target/doc/docs/ cp tarpaulin-report.html target/doc/rustframe/
cp tarpaulin-report.json target/doc/docs/ cp tarpaulin-report.json target/doc/rustframe/
cp tarpaulin-badge.json target/doc/docs/ cp tarpaulin-badge.json target/doc/rustframe/
cp last-commit-date.json target/doc/docs/ cp last-commit-date.json target/doc/rustframe/
# cp -r .github target/doc/docs mkdir -p target/doc/rustframe/.github
cp .github/rustframe_logo.png target/doc/docs/ cp .github/rustframe_logo.png target/doc/rustframe/.github/
# echo "<meta http-equiv=\"refresh\" content=\"0; url=docs\">" > target/doc/index.html echo "<meta http-equiv=\"refresh\" content=\"0; url=rustframe\">" > target/doc/index.html
touch target/doc/.nojekyll
# copy the benchmark report to the output directory
cp -r benchmark-report target/doc/
- name: Add index.html to output directory
run: |
cp .github/htmldocs/index.html target/doc/index.html
cp .github/rustframe_logo.png target/doc/rustframe_logo.png
- name: Upload Pages artifact - name: Upload Pages artifact
# if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' if: github.event_name == 'push'
uses: actions/upload-pages-artifact@v3 uses: actions/upload-pages-artifact@v3
with: with:
path: target/doc/ path: target/doc/
- name: Deploy to GitHub Pages - name: Deploy to GitHub Pages
# if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' if: github.event_name == 'push'
uses: actions/deploy-pages@v4 uses: actions/deploy-pages@v4

View File

@@ -1,4 +1,4 @@
name: run-benchmarks name: Run benchmarks
on: on:
workflow_dispatch: workflow_dispatch:
@@ -7,22 +7,8 @@ on:
- main - main
jobs: jobs:
pick-runner:
runs-on: ubuntu-latest
outputs:
runner: ${{ steps.choose.outputs.use-runner }}
steps:
- uses: actions/checkout@v4
- id: choose
uses: ./.github/actions/runner-fallback
with:
primary-runner: "self-hosted"
fallback-runner: "ubuntu-latest"
github-token: ${{ secrets.CUSTOM_GH_TOKEN }}
run-benchmarks: run-benchmarks:
needs: pick-runner runs-on: ubuntu-latest
runs-on: ${{ fromJson(needs.pick-runner.outputs.runner) }}
steps: steps:
- name: Checkout code - name: Checkout code
@@ -33,31 +19,11 @@ jobs:
with: with:
toolchain: stable toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Setup venv
run: |
uv venv
uv pip install pandas
uv run .github/scripts/custom_benchmark_report.py --dry-run
- name: Run benchmarks - name: Run benchmarks
run: cargo bench --features bench run: cargo bench
- name: Generate custom benchmark reports
run: |
if [ -d ./target/criterion ]; then
echo "Found benchmark reports, generating custom report..."
else
echo "No benchmark reports found, skipping custom report generation."
exit 1
fi
uv run .github/scripts/custom_benchmark_report.py
- name: Upload benchmark reports - name: Upload benchmark reports
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: benchmark-reports-${{ github.sha }} name: benchmark-reports-${{ github.sha }}
path: ./target/criterion/ path: ./target/criterion/

View File

@@ -11,51 +11,25 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
pick-runner:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
outputs:
runner: ${{ steps.choose.outputs.use-runner }}
steps:
- uses: actions/checkout@v4
- id: choose
uses: ./.github/actions/runner-fallback
with:
primary-runner: "self-hosted"
fallback-runner: "ubuntu-latest"
github-token: ${{ secrets.CUSTOM_GH_TOKEN }}
run-unit-tests: run-unit-tests:
needs: pick-runner
if: github.event.pull_request.draft == false if: github.event.pull_request.draft == false
name: run-unit-tests name: run-unit-tests
runs-on: ${{ fromJson(needs.pick-runner.outputs.runner) }} runs-on: ubuntu-latest
env: env:
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Rust - name: Install Rust
uses: actions-rs/toolchain@v1 run: rustup update stable
with:
toolchain: stable
override: true
- name: Install cargo-llvm-cov - name: Install cargo-llvm-cov
uses: taiki-e/install-action@cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov
- name: Generate code coverage
- name: Run doctests run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
run: cargo test --doc --release - name: Run doc-tests
run: cargo test --doc --all-features --workspace --release
- name: Run unit tests with code coverage
run: cargo llvm-cov --release --lcov --output-path lcov.info
- name: Test docs generation - name: Test docs generation
run: cargo doc --no-deps --release run: cargo doc --no-deps --release
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@v3
with: with:

2
.gitignore vendored
View File

@@ -15,5 +15,3 @@ data/
.vscode/ .vscode/
tarpaulin-report.* tarpaulin-report.*
.github/htmldocs/rustframe_logo.png

795
Cargo.lock generated Normal file
View File

@@ -0,0 +1,795 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bumpalo"
version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
version = "1.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e3a13707ac958681c13b39b458c073d0d9bc8a22cb1b2f4c8e55eb72c13f362"
dependencies = [
"shlex",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-link",
]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "clap"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
dependencies = [
"bitflags",
"clap_lex",
"indexmap",
"textwrap",
]
[[package]]
name = "clap_lex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "criterion"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb"
dependencies = [
"anes",
"atty",
"cast",
"ciborium",
"clap",
"criterion-plot",
"itertools",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crunchy"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "half"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "iana-time-zone"
version = "0.1.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"log",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "indexmap"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown",
]
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "js-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
dependencies = [
"once_cell",
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.172"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
[[package]]
name = "log"
version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "os_str_bytes"
version = "6.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "proc-macro2"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "regex"
version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rustframe"
version = "0.0.1-a.0"
dependencies = [
"chrono",
"criterion",
]
[[package]]
name = "rustversion"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2"
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "serde"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "syn"
version = "2.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "textwrap"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "wasm-bindgen"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
dependencies = [
"cfg-if",
"once_cell",
"rustversion",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
dependencies = [
"bumpalo",
"log",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
dependencies = [
"unicode-ident",
]
[[package]]
name = "web-sys"
version = "0.3.77"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.61.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-implement"
version = "0.60.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "windows-interface"
version = "0.59.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "windows-link"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38"
[[package]]
name = "windows-result"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "rustframe" name = "rustframe"
version = "0.0.1-a.20250716" version = "0.0.1-a.0"
edition = "2021" edition = "2021"
license = "GPL-3.0-or-later" license = "GPL-3.0-or-later"
readme = "README.md" readme = "README.md"
@@ -13,13 +13,10 @@ crate-type = ["cdylib", "lib"]
[dependencies] [dependencies]
chrono = "^0.4.10" chrono = "^0.4.10"
criterion = { version = "0.5", features = ["html_reports"], optional = true }
rand = "^0.9.1"
[features] [dev-dependencies]
bench = ["dep:criterion"] criterion = { version = "0.4", features = ["html_reports"] }
[[bench]] [[bench]]
name = "benchmarks" name = "benchmarks"
harness = false harness = false
required-features = ["bench"]

144
README.md
View File

@@ -1,68 +1,38 @@
# rustframe
<!-- # <img align="center" alt="Rustframe" src=".github/rustframe_logo.png" height="50px" /> rustframe --> # <img align="center" alt="Rustframe" src=".github/rustframe_logo.png" height="50" /> rustframe
<!-- though the centre tag doesn't work as it would normally, it achieves the desired effect --> <!-- though the centre tag doesn't work as it would noramlly, it achieves the desired effect -->
📚 [Docs](https://magnus167.github.io/rustframe/) | 🐙 [GitHub](https://github.com/Magnus167/rustframe) | 🌐 [Gitea mirror](https://gitea.nulltech.uk/Magnus167/rustframe) | 🦀 [Crates.io](https://crates.io/crates/rustframe) | 🔖 [docs.rs](https://docs.rs/rustframe/latest/rustframe/) 📚 [Docs](https://magnus167.github.io/rustframe/) | 🐙 [GitHub](https://github.com/Magnus167/rustframe) | 🌐 [Gitea mirror](https://gitea.nulltech.uk/Magnus167/rustframe) | 🦀 [Crates.io](https://crates.io/crates/rustframe) | 🔖 [docs.rs](https://docs.rs/rustframe/latest/rustframe/)
<!-- [![Last commit](https://img.shields.io/endpoint?url=https://magnus167.github.io/rustframe/rustframe/last-commit-date.json)](https://github.com/Magnus167/rustframe) --> <!-- [![Last commit](https://img.shields.io/endpoint?url=https://magnus167.github.io/rustframe/rustframe/last-commit-date.json)](https://github.com/Magnus167/rustframe) -->
[![codecov](https://codecov.io/gh/Magnus167/rustframe/graph/badge.svg?token=J7ULJEFTVI)](https://codecov.io/gh/Magnus167/rustframe) [![codecov](https://codecov.io/gh/Magnus167/rustframe/graph/badge.svg?token=J7ULJEFTVI)](https://codecov.io/gh/Magnus167/rustframe)
[![Coverage](https://img.shields.io/endpoint?url=https://magnus167.github.io/rustframe/docs/tarpaulin-badge.json)](https://magnus167.github.io/rustframe/docs/tarpaulin-report.html) [![Coverage](https://img.shields.io/endpoint?url=https://magnus167.github.io/rustframe/rustframe/tarpaulin-badge.json)](https://magnus167.github.io/rustframe/rustframe/tarpaulin-report.html)
--- ---
## Rustframe: _A lightweight dataframe & math toolkit for Rust_ ## Rustframe: *A lightweight dataframe & math toolkit for Rust*
Rustframe provides intuitive dataframe, matrix, and series operations for data analysis and manipulation. Rustframe provides intuitive dataframe, matrix, and series operations small-to-mid scale data analysis and manipulation.
Rustframe keeps things simple, safe, and readable. It is handy for quick numeric experiments and small analytical tasks as well as for educational purposes. It is designed to be easy to use and understand, with a clean API implemented in 100% safe Rust. Rustframe keeps things simple, safe, and readable. It is handy for quick numeric experiments and small analytical tasks, but it is **not** meant to compete with powerhouse crates like `polars` or `ndarray`.
Rustframe is an educational project, and is not intended for production use. It is **not** meant to compete with powerhouse crates like `polars` or `ndarray`. It is a work in progress, and the API is subject to change. There are no guarantees of stability or performance, and it is not optimized for large datasets or high-performance computing.
### What it offers ### What it offers
- **Matrix operations** - Element-wise arithmetic, boolean logic, transpose, and more. - **Math that reads like math** - elementwise `+`, ``, `×`, `÷` on entire frames or scalars.
- **Math that reads like math** - element-wise `+`, ``, `×`, `÷` on entire frames or scalars. - **Broadcast & reduce** - sum, product, any/all across rows or columns without boilerplate.
- **Frames** - Column major data structure for single-type data, with labeled columns and typed row indices. - **Boolean masks made simple** - chain comparisons, combine with `&`/`|`, get a tidy `BoolMatrix` back.
- **Compute module** - Implements various statistical computations and machine learning models. - **Datecentric row index** - businessday ranges and calendar slicing built in.
- **Pure safe Rust** - 100% safe, zero `unsafe`.
- **[Coming Soon]** _DataFrame_ - Multi-type data structure for heterogeneous data, with labeled columns and typed row indices.
- **[Coming Soon]** _Random number utils_ - Random number generation utilities for statistical sampling and simulations. (Currently using the [`rand`](https://crates.io/crates/rand) crate.)
#### Matrix and Frame functionality
- **Matrix operations** - Element-wise arithmetic, boolean logic, transpose, and more.
- **Frame operations** - Column manipulation, sorting, and more.
#### Compute Module
The `compute` module provides implementations for various statistical computations and machine learning models.
**Statistics, Data Analysis, and Machine Learning:**
- Correlation analysis
- Descriptive statistics
- Distributions
- Inferential statistics
- Dense Neural Networks
- Gaussian Naive Bayes
- K-Means Clustering
- Linear Regression
- Logistic Regression
- Principal Component Analysis
### Heads up ### Heads up
- **Not memoryefficient (yet)** - footprint needs work. - **Not memoryefficient (yet)** - footprint needs work.
- **The feature set is still limited** - expect missing pieces. - **Feature set still small** - expect missing pieces.
### Somewhere down the line ### On the horizon
- Optional GPU acceleration (Vulkan or similar) for heavier workloads. - Optional GPU help (Vulkan or similar) for heavier workloads.
- Straightforward Python bindings using `pyo3`. - Straightforward Python bindings using `pyo3`.
--- ---
@@ -74,16 +44,17 @@ use chrono::NaiveDate;
use rustframe::{ use rustframe::{
frame::{Frame, RowIndex}, frame::{Frame, RowIndex},
matrix::{BoolOps, Matrix, SeriesOps}, matrix::{BoolOps, Matrix, SeriesOps},
utils::{DateFreq, BDatesList}, utils::{BDateFreq, BDatesList},
}; };
let n_periods = 4; let n_periods = 4;
// Four business days starting 2024-01-02 // Four business days starting 20240102
let dates: Vec<NaiveDate> = let dates: Vec<NaiveDate> =
BDatesList::from_n_periods("2024-01-02".to_string(), DateFreq::Daily, n_periods) BDatesList::from_n_periods("2024-01-02".to_string(), BDateFreq::Daily, n_periods)
.unwrap() .unwrap()
.list().unwrap(); .list()
.unwrap();
let col_names: Vec<String> = vec!["a".to_string(), "b".to_string()]; let col_names: Vec<String> = vec!["a".to_string(), "b".to_string()];
@@ -114,88 +85,17 @@ let result: Matrix<f64> = result / 2.0; // divide by scalar
let check: bool = result.eq_elem(ma.clone()).all(); let check: bool = result.eq_elem(ma.clone()).all();
assert!(check); assert!(check);
// Alternatively: // The above math can also be written as:
let check: bool = (&(&(&(&ma + 1.0) - 1.0) * 2.0) / 2.0) let check: bool = (&(&(&(&ma + 1.0) - 1.0) * 2.0) / 2.0)
.eq_elem(ma.clone()) .eq_elem(ma.clone())
.all(); .all();
assert!(check); assert!(check);
// or even as: // The above math can also be written as:
let check: bool = ((((ma.clone() + 1.0) - 1.0) * 2.0) / 2.0) let check: bool = ((((ma.clone() + 1.0) - 1.0) * 2.0) / 2.0)
.eq_elem(ma.clone()) .eq_elem(ma)
.all(); .all();
assert!(check); assert!(check);
// Matrix multiplication
let mc: Matrix<f64> = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
let md: Matrix<f64> = Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
let mul_result: Matrix<f64> = mc.matrix_mul(&md);
// Expected:
// 1*5 + 3*6 = 5 + 18 = 23
// 2*5 + 4*6 = 10 + 24 = 34
// 1*7 + 3*8 = 7 + 24 = 31
// 2*7 + 4*8 = 14 + 32 = 46
assert_eq!(mul_result.data(), &[23.0, 34.0, 31.0, 46.0]);
// Dot product (alias for matrix_mul for FloatMatrix)
let dot_result: Matrix<f64> = mc.dot(&md);
assert_eq!(dot_result, mul_result);
// Transpose
let original_matrix: Matrix<f64> = Matrix::from_cols(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
// Original:
// 1 4
// 2 5
// 3 6
let transposed_matrix: Matrix<f64> = original_matrix.transpose();
// Transposed:
// 1 2 3
// 4 5 6
assert_eq!(transposed_matrix.rows(), 2);
assert_eq!(transposed_matrix.cols(), 3);
assert_eq!(transposed_matrix.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
// Map
let matrix = Matrix::from_cols(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
// Map function to double each value
let mapped_matrix = matrix.map(|x| x * 2.0);
// Expected data after mapping
// 2 8
// 4 10
// 6 12
assert_eq!(mapped_matrix.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
// Zip
let a = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); // 2x2 matrix
let b = Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]); // 2x2 matrix
// Zip function to add corresponding elements
let zipped_matrix = a.zip(&b, |x, y| x + y);
// Expected data after zipping
// 6 10
// 8 12
assert_eq!(zipped_matrix.data(), &[6.0, 8.0, 10.0, 12.0]);
```
### More examples
See the [examples](./examples/) directory for some demonstrations of Rustframe's syntax and functionality.
To run the examples, use:
```bash
cargo run --example <example_name>
```
E.g. to run the `game_of_life` example:
```bash
cargo run --example game_of_life
```
### Running benchmarks
To run the benchmarks, use:
```bash
cargo bench --features "bench"
``` ```

View File

@@ -1,22 +1,35 @@
// Combined benchmarks // Combined benchmarks for rustframe
use chrono::NaiveDate; use chrono::NaiveDate;
use criterion::{criterion_group, criterion_main, Criterion}; use criterion::{criterion_group, criterion_main, Criterion};
// Import Duration for measurement_time and warm_up_time
use rustframe::{ use rustframe::{
frame::{Frame, RowIndex}, frame::{Frame, RowIndex},
matrix::{Axis, BoolMatrix, Matrix, SeriesOps}, matrix::{BoolMatrix, Matrix},
utils::{DateFreq, DatesList}, utils::{BDateFreq, BDatesList},
}; };
use std::time::Duration; use std::time::Duration;
// Define size categories // You can define a custom Criterion configuration function
const SIZES_SMALL: [usize; 1] = [1]; // This will be passed to the criterion_group! macro
const SIZES_MEDIUM: [usize; 3] = [100, 250, 500]; pub fn for_short_runs() -> Criterion {
const SIZES_LARGE: [usize; 1] = [1000]; Criterion::default()
// (samples != total iterations)
// limits the number of statistical data points.
.sample_size(50)
// measurement time per sample
.measurement_time(Duration::from_millis(2000))
// reduce warm-up time as well for faster overall run
.warm_up_time(Duration::from_millis(50))
// You could also make it much shorter if needed, e.g., 50ms measurement, 100ms warm-up
// .measurement_time(Duration::from_millis(50))
// .warm_up_time(Duration::from_millis(100))
}
// Modified benchmark functions to accept a slice of sizes fn bool_matrix_operations_benchmark(c: &mut Criterion) {
fn bool_matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) { let sizes = [1, 100, 1000];
for &size in sizes { // let sizes = [1000];
for &size in &sizes {
let data1: Vec<bool> = (0..size * size).map(|x| x % 2 == 0).collect(); let data1: Vec<bool> = (0..size * size).map(|x| x % 2 == 0).collect();
let data2: Vec<bool> = (0..size * size).map(|x| x % 3 == 0).collect(); let data2: Vec<bool> = (0..size * size).map(|x| x % 3 == 0).collect();
let bm1 = BoolMatrix::from_vec(data1.clone(), size, size); let bm1 = BoolMatrix::from_vec(data1.clone(), size, size);
@@ -48,8 +61,11 @@ fn bool_matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
} }
} }
fn matrix_boolean_operations_benchmark(c: &mut Criterion, sizes: &[usize]) { fn matrix_boolean_operations_benchmark(c: &mut Criterion) {
for &size in sizes { let sizes = [1, 100, 1000];
// let sizes = [1000];
for &size in &sizes {
let data1: Vec<bool> = (0..size * size).map(|x| x % 2 == 0).collect(); let data1: Vec<bool> = (0..size * size).map(|x| x % 2 == 0).collect();
let data2: Vec<bool> = (0..size * size).map(|x| x % 3 == 0).collect(); let data2: Vec<bool> = (0..size * size).map(|x| x % 3 == 0).collect();
let bm1 = BoolMatrix::from_vec(data1.clone(), size, size); let bm1 = BoolMatrix::from_vec(data1.clone(), size, size);
@@ -81,8 +97,11 @@ fn matrix_boolean_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
} }
} }
fn matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) { fn matrix_operations_benchmark(c: &mut Criterion) {
for &size in sizes { let sizes = [1, 100, 1000];
// let sizes = [1000];
for &size in &sizes {
let data: Vec<f64> = (0..size * size).map(|x| x as f64).collect(); let data: Vec<f64> = (0..size * size).map(|x| x as f64).collect();
let ma = Matrix::from_vec(data.clone(), size, size); let ma = Matrix::from_vec(data.clone(), size, size);
@@ -109,50 +128,10 @@ fn matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
let _result = &ma / 2.0; let _result = &ma / 2.0;
}); });
}); });
c.bench_function(
&format!("matrix matrix_multiply ({}x{})", size, size),
|b| {
b.iter(|| {
let _result = ma.matrix_mul(&ma);
});
},
);
c.bench_function(&format!("matrix sum_horizontal ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.sum_horizontal();
});
});
c.bench_function(&format!("matrix sum_vertical ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.sum_vertical();
});
});
c.bench_function(
&format!("matrix prod_horizontal ({}x{})", size, size),
|b| {
b.iter(|| {
let _result = ma.prod_horizontal();
});
},
);
c.bench_function(&format!("matrix prod_vertical ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.prod_vertical();
});
});
c.bench_function(&format!("matrix apply_axis ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
});
});
c.bench_function(&format!("matrix transpose ({}x{})", size, size), |b| {
b.iter(|| {
let _result = ma.transpose();
});
});
} }
for &size in sizes { // Benchmarking matrix addition
for &size in &sizes {
let data1: Vec<f64> = (0..size * size).map(|x| x as f64).collect(); let data1: Vec<f64> = (0..size * size).map(|x| x as f64).collect();
let data2: Vec<f64> = (0..size * size).map(|x| (x + 1) as f64).collect(); let data2: Vec<f64> = (0..size * size).map(|x| (x + 1) as f64).collect();
let ma = Matrix::from_vec(data1.clone(), size, size); let ma = Matrix::from_vec(data1.clone(), size, size);
@@ -184,151 +163,44 @@ fn matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
} }
} }
fn generate_frame(size: usize) -> Frame<f64> { fn benchmark_frame_operations(c: &mut Criterion) {
let data: Vec<f64> = (0..size * size).map(|x| x as f64).collect(); let n_periods = 1000;
let n_cols = 1000;
let dates: Vec<NaiveDate> = let dates: Vec<NaiveDate> =
DatesList::from_n_periods("2000-01-01".to_string(), DateFreq::Daily, size) BDatesList::from_n_periods("2024-01-02".to_string(), BDateFreq::Daily, n_periods)
.unwrap() .unwrap()
.list() .list()
.unwrap(); .unwrap();
let col_names: Vec<String> = (1..=size).map(|i| format!("col_{}", i)).collect();
Frame::new(
Matrix::from_vec(data.clone(), size, size),
col_names,
Some(RowIndex::Date(dates)),
)
}
fn benchmark_frame_operations(c: &mut Criterion, sizes: &[usize]) { // let col_names= str(i) for i in range(1, 1000)
for &size in sizes { let col_names: Vec<String> = (1..=n_cols).map(|i| format!("col_{}", i)).collect();
let fa = generate_frame(size);
let fb = generate_frame(size);
c.bench_function(&format!("frame add ({}x{})", size, size), |b| { let data1: Vec<f64> = (0..n_periods * n_cols).map(|x| x as f64).collect();
b.iter(|| { let data2: Vec<f64> = (0..n_periods * n_cols).map(|x| (x + 1) as f64).collect();
let _result = &fa + &fb; let ma = Matrix::from_vec(data1.clone(), n_periods, n_cols);
}); let mb = Matrix::from_vec(data2.clone(), n_periods, n_cols);
});
c.bench_function(&format!("frame subtract ({}x{})", size, size), |b| { let fa = Frame::new(
b.iter(|| { ma.clone(),
let _result = &fa - &fb; col_names.clone(),
}); Some(RowIndex::Date(dates.clone())),
}); );
let fb = Frame::new(mb, col_names, Some(RowIndex::Date(dates)));
c.bench_function(&format!("frame multiply ({}x{})", size, size), |b| { c.bench_function("frame element-wise multiply (1000x1000)", |b| {
b.iter(|| { b.iter(|| {
let _result = &fa * &fb; let _result = &fa * &fb;
}); });
}); });
c.bench_function(&format!("frame divide ({}x{})", size, size), |b| {
b.iter(|| {
let _result = &fa / &fb;
});
});
c.bench_function(&format!("frame matrix_multiply ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.matrix_mul(&fb);
});
});
c.bench_function(&format!("frame sum_horizontal ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.sum_horizontal();
});
});
c.bench_function(&format!("frame sum_vertical ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.sum_vertical();
});
});
c.bench_function(&format!("frame prod_horizontal ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.prod_horizontal();
});
});
c.bench_function(&format!("frame prod_vertical ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.prod_vertical();
});
});
c.bench_function(&format!("frame apply_axis ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
});
});
c.bench_function(&format!("frame transpose ({}x{})", size, size), |b| {
b.iter(|| {
let _result = fa.transpose();
});
});
}
}
// Runner functions for each size category
fn run_benchmarks_small(c: &mut Criterion) {
bool_matrix_operations_benchmark(c, &SIZES_SMALL);
matrix_boolean_operations_benchmark(c, &SIZES_SMALL);
matrix_operations_benchmark(c, &SIZES_SMALL);
benchmark_frame_operations(c, &SIZES_SMALL);
}
fn run_benchmarks_medium(c: &mut Criterion) {
bool_matrix_operations_benchmark(c, &SIZES_MEDIUM);
matrix_boolean_operations_benchmark(c, &SIZES_MEDIUM);
matrix_operations_benchmark(c, &SIZES_MEDIUM);
benchmark_frame_operations(c, &SIZES_MEDIUM);
}
fn run_benchmarks_large(c: &mut Criterion) {
bool_matrix_operations_benchmark(c, &SIZES_LARGE);
matrix_boolean_operations_benchmark(c, &SIZES_LARGE);
matrix_operations_benchmark(c, &SIZES_LARGE);
benchmark_frame_operations(c, &SIZES_LARGE);
}
// Configuration functions for different size categories
fn config_small_arrays() -> Criterion {
Criterion::default()
.sample_size(500)
.measurement_time(Duration::from_millis(100))
.warm_up_time(Duration::from_millis(5))
}
fn config_medium_arrays() -> Criterion {
Criterion::default()
.sample_size(100)
.measurement_time(Duration::from_millis(2000))
.warm_up_time(Duration::from_millis(100))
}
fn config_large_arrays() -> Criterion {
Criterion::default()
.sample_size(50)
.measurement_time(Duration::from_millis(5000))
.warm_up_time(Duration::from_millis(200))
} }
// Define the criterion group and pass the custom configuration function
criterion_group!( criterion_group!(
name = benches_small_arrays; name = combined_benches;
config = config_small_arrays(); config = for_short_runs(); // Use the custom configuration here
targets = run_benchmarks_small targets = bool_matrix_operations_benchmark,
); matrix_boolean_operations_benchmark,
criterion_group!( matrix_operations_benchmark,
name = benches_medium_arrays; benchmark_frame_operations
config = config_medium_arrays();
targets = run_benchmarks_medium
);
criterion_group!(
name = benches_large_arrays;
config = config_large_arrays();
targets = run_benchmarks_large
);
criterion_main!(
benches_small_arrays,
benches_medium_arrays,
benches_large_arrays
); );
criterion_main!(combined_benches);

View File

@@ -1,444 +0,0 @@
use rand::{self, Rng};
use rustframe::matrix::{BoolMatrix, BoolOps, IntMatrix, Matrix};
use std::{thread, time};
/// Side length of the square game board, in cells.
const BOARD_SIZE: usize = 50; // Size of the board (50x50)
/// Delay between generations; controls the animation speed.
const TICK_DURATION_MS: u64 = 10; // Milliseconds per frame
/// Entry point: runs Conway's Game of Life in the terminal, forever.
///
/// The loop never exits on its own; interrupt the process (Ctrl-C) to stop.
/// Whenever the board goes stable, starts oscillating, or dies out, fresh
/// random patterns are injected so the animation keeps moving.
fn main() {
    // Initialize the game board.
    // This demonstrates `BoolMatrix::from_vec`.
    let mut current_board =
        BoolMatrix::from_vec(vec![false; BOARD_SIZE * BOARD_SIZE], BOARD_SIZE, BOARD_SIZE);
    // One prime per cell; used by `hash_board` to fingerprint board states.
    let primes = generate_primes((BOARD_SIZE * BOARD_SIZE) as i32);
    // Seed the empty board with random gliders and pulsars.
    add_simulated_activity(&mut current_board, BOARD_SIZE);
    let mut generation_count: u32 = 0;
    // `previous_board_state` will store a clone of the board.
    // This demonstrates `Matrix::clone()` and later `PartialEq` for `Matrix`.
    let mut previous_board_state: Option<BoolMatrix> = None;
    // Rolling history of board fingerprints, consumed by `detect_repeating_state`.
    let mut board_hashes = Vec::new();
    // let mut print_board_bool = true;
    // Counter that throttles redrawing to every 10th generation.
    let mut print_bool_int = 0;
    loop {
        // print!("{}[2J", 27 as char); // Clear screen and move cursor to top-left
        // if print_board_bool {
        if print_bool_int % 10 == 0 {
            // ANSI escape sequence: clear the screen before redrawing.
            print!("{}[2J", 27 as char);
            println!("Conway's Game of Life - Generation: {}", generation_count);
            print_board(&current_board);
            println!("Alive cells: {}", &current_board.count());
            // print_board_bool = false;
            print_bool_int = 0;
        } else {
            // print_board_bool = true;
            print_bool_int += 1;
        }
        // `current_board.count()` demonstrates a method from `BoolOps`.
        board_hashes.push(hash_board(&current_board, primes.clone()));
        // Re-seed when the board is identical to the previous generation...
        if detect_stable_state(&current_board, &previous_board_state) {
            println!(
                "\nStable state detected at generation {}.",
                generation_count
            );
            add_simulated_activity(&mut current_board, BOARD_SIZE);
        }
        // ...or when it oscillates between two states...
        if detect_repeating_state(&mut board_hashes) {
            println!(
                "\nRepeating state detected at generation {}.",
                generation_count
            );
            add_simulated_activity(&mut current_board, BOARD_SIZE);
        }
        // ...or when every cell has died.
        if !&current_board.any() {
            println!("\nExtinction at generation {}.", generation_count);
            add_simulated_activity(&mut current_board, BOARD_SIZE);
        }
        // `current_board.clone()` demonstrates `Clone` for `Matrix`.
        previous_board_state = Some(current_board.clone());
        // This is the core call to your game logic.
        let next_board = game_of_life_next_frame(&current_board);
        current_board = next_board;
        generation_count += 1;
        thread::sleep(time::Duration::from_millis(TICK_DURATION_MS));
        // if generation_count > 500 { // Optional limit
        //     println!("\nReached generation limit.");
        //     break;
        // }
    }
}
/// Prints the Game of Life board to the console.
///
/// - `board`: A reference to the `BoolMatrix` representing the current game state.
/// This function demonstrates `board.rows()`, `board.cols()`, and `board[(r, c)]` (Index trait).
/// Renders the board to stdout as an ASCII-art grid.
///
/// Live cells are drawn as `██`, dead cells as two spaces, inside a
/// `+--…--+` border. The whole frame is assembled in one `String` and
/// written with a single `print!` call to reduce flicker.
/// Demonstrates `board.rows()`, `board.cols()`, and `board[(r, c)]`
/// (the `Index` trait).
fn print_board(board: &BoolMatrix) {
    // Top and bottom borders are identical: "+" + "--" per column + "+\n".
    let horizontal_border = {
        let mut line = String::from("+");
        line.push_str(&"--".repeat(board.cols()));
        line.push_str("+\n");
        line
    };
    let mut frame = String::new();
    frame.push_str(&horizontal_border);
    for row in 0..board.rows() {
        frame.push_str("| ");
        for col in 0..board.cols() {
            // Index access on `Matrix<bool>`.
            let glyph = if board[(row, col)] { "██" } else { "  " };
            frame.push_str(glyph);
        }
        frame.push_str(" |\n");
    }
    frame.push_str(&horizontal_border);
    // Trailing blank line separates consecutive frames.
    frame.push('\n');
    print!("{}", frame);
}
/// Helper function to create a shifted version of the game board.
/// (Using the version provided by the user)
///
/// - `game`: The current state of the Game of Life as a `BoolMatrix`.
/// - `dr`: The row shift (delta row). Positive shifts down, negative shifts up.
/// - `dc`: The column shift (delta column). Positive shifts right, negative shifts left.
///
/// Returns an `IntMatrix` of the same dimensions as `game`.
/// - Cells in the shifted matrix get value `1` if the corresponding source cell in `game` was `true` (alive).
/// - Cells that would source from outside `game`'s bounds (due to the shift) get value `0`.
fn get_shifted_neighbor_layer(game: &BoolMatrix, dr: isize, dc: isize) -> IntMatrix {
let rows = game.rows();
let cols = game.cols();
if rows == 0 || cols == 0 {
// Handle 0x0 case, other 0-dim cases panic in Matrix::from_vec
return IntMatrix::from_vec(vec![], 0, 0);
}
// Initialize with a matrix of 0s using from_vec.
// This demonstrates creating an IntMatrix and then populating it.
let mut shifted_layer = IntMatrix::from_vec(vec![0i32; rows * cols], rows, cols);
for r_target in 0..rows {
// Iterate over cells in the *new* (target) shifted matrix
for c_target in 0..cols {
// Calculate where this target cell would have come from in the *original* game matrix
let r_source = r_target as isize - dr;
let c_source = c_target as isize - dc;
// Check if the source coordinates are within the bounds of the original game matrix
if r_source >= 0
&& r_source < rows as isize
&& c_source >= 0
&& c_source < cols as isize
{
// If the source cell in the original game was alive...
if game[(r_source as usize, c_source as usize)] {
// Demonstrates Index access on BoolMatrix
// ...then this cell in the shifted layer is 1.
shifted_layer[(r_target, c_target)] = 1; // Demonstrates IndexMut access on IntMatrix
}
}
// Else (source is out of bounds): it remains 0, as initialized.
}
}
shifted_layer // Return the constructed IntMatrix
}
/// Calculates the next generation of Conway's Game of Life.
///
/// This implementation uses a broadcast-like approach by creating shifted layers
/// for each neighbor and summing them up, then applying rules element-wise.
///
/// - `current_game`: A `&BoolMatrix` representing the current state (true=alive).
///
/// Returns: A new `BoolMatrix` for the next generation.
/// Computes the next Game of Life generation from `current_game`.
///
/// Instead of counting neighbors cell by cell, this builds eight shifted
/// copies of the board (one per neighbor direction), sums them into a
/// per-cell neighbor count with element-wise matrix addition, and then
/// applies the survival/birth rules with element-wise boolean operations.
///
/// Returns a new `BoolMatrix` for the next generation.
pub fn game_of_life_next_frame(current_game: &BoolMatrix) -> BoolMatrix {
    if current_game.rows() == 0 && current_game.cols() == 0 {
        // An empty board stays empty. Other invalid zero-dimension shapes
        // would panic inside `Matrix::from_vec`, as in the rest of the crate.
        return BoolMatrix::from_vec(vec![], 0, 0);
    }
    // The eight neighbor directions as (row delta, col delta).
    let neighbor_offsets: [(isize, isize); 8] = [
        (-1, -1),
        (-1, 0),
        (-1, 1), // NW, N, NE
        (0, -1),
        (0, 1), // W, E
        (1, -1),
        (1, 0),
        (1, 1), // SW, S, SE
    ];
    // Fold the eight shifted layers into one per-cell neighbor count.
    // The first layer seeds the accumulator; the remaining seven are added
    // via `Matrix + Matrix` (element-wise, consuming both operands).
    let (seed_dr, seed_dc) = neighbor_offsets[0];
    let neighbor_counts = neighbor_offsets[1..].iter().fold(
        get_shifted_neighbor_layer(current_game, seed_dr, seed_dc),
        |counts, &(dr, dc)| counts + get_shifted_neighbor_layer(current_game, dr, dc),
    );
    // Rule: a cell is alive next generation iff
    //   (alive now AND 2 or 3 neighbors) OR (dead now AND exactly 3 neighbors).
    // `eq_elem` compares every element against a scalar, yielding a BoolMatrix.
    let has_two = neighbor_counts.eq_elem(2);
    let has_three = neighbor_counts.eq_elem(3);
    // `has_three` is needed for both clauses, hence the clone.
    let survives = current_game & &(has_two | has_three.clone());
    // `!current_game` is element-wise NOT, borrowing the operand.
    let births = !current_game & &has_three;
    survives | births
}
/// Stamps a glider at a random position on the board.
///
/// Demonstrates `IndexMut` on `BoolMatrix` (`board[(r, c)] = true`).
/// The random offset is drawn so the 3x3 pattern fits within `board_size`;
/// a defensive bounds check also guards against a `board` smaller than
/// `board_size`.
pub fn generate_glider(board: &mut BoolMatrix, board_size: usize) {
    let mut rng = rand::rng();
    let row = rng.random_range(0..(board_size - 3));
    let col = rng.random_range(0..(board_size - 3));
    if board.rows() >= row + 3 && board.cols() >= col + 3 {
        // Relative cell positions of the glider pattern.
        let glider_cells = [(0, 1), (1, 2), (2, 0), (2, 1), (2, 2)];
        for (dr, dc) in glider_cells {
            board[(row + dr, col + dc)] = true;
        }
    }
}
/// Stamps a pulsar pattern at a random position on the board.
///
/// Demonstrates `IndexMut` on `BoolMatrix`. The 17x17 bounding box used
/// for placement is checked against the board before any cell is written.
pub fn generate_pulsar(board: &mut BoolMatrix, board_size: usize) {
    let mut rng = rand::rng();
    let row = rng.random_range(0..(board_size - 17));
    let col = rng.random_range(0..(board_size - 17));
    // Guard clause: skip entirely if the pattern would not fit.
    if board.rows() < row + 17 || board.cols() < col + 17 {
        return;
    }
    // Cell offsets of the pattern, relative to (row, col).
    let pulsar_coords = [
        (2, 4),
        (2, 5),
        (2, 6),
        (2, 10),
        (2, 11),
        (2, 12),
        (4, 2),
        (4, 7),
        (4, 9),
        (4, 14),
        (5, 2),
        (5, 7),
        (5, 9),
        (5, 14),
        (6, 2),
        (6, 7),
        (6, 9),
        (6, 14),
        (7, 4),
        (7, 5),
        (7, 6),
        (7, 10),
        (7, 11),
        (7, 12),
    ];
    for (dr, dc) in pulsar_coords {
        board[(row + dr, col + dc)] = true;
    }
}
/// Returns `true` when the board is identical to the previous generation.
///
/// Uses `PartialEq` on `Matrix`. With no previous state recorded yet the
/// board cannot be stable, so the result is `false`.
pub fn detect_stable_state(
    current_board: &BoolMatrix,
    previous_board_state: &Option<BoolMatrix>,
) -> bool {
    previous_board_state
        .as_ref()
        .map_or(false, |prev| prev == current_board)
}
/// Produces an order-sensitive fingerprint of the board state.
///
/// Each cell is encoded as 0/1, multiplied element-wise by a fixed prime
/// (one prime per cell), and the products are summed. Demonstrates
/// `Matrix * Matrix` (element-wise) and `Matrix::data`.
///
/// Panics (inside `Matrix::from_vec`) if `primes.len()` does not match the
/// number of cells.
pub fn hash_board(board: &BoolMatrix, primes: Vec<i32>) -> usize {
    // Encode alive/dead as 1/0 integers, in `data()` order.
    let cell_flags: Vec<i32> = board
        .data()
        .iter()
        .map(|&alive| i32::from(alive))
        .collect();
    let flags_matrix = Matrix::from_vec(cell_flags, board.rows(), board.cols());
    let primes_matrix = Matrix::from_vec(primes, flags_matrix.rows(), flags_matrix.cols());
    // Element-wise product keeps exactly the primes of the live cells.
    let weighted = flags_matrix * primes_matrix;
    let fingerprint: i32 = weighted.data().iter().sum();
    fingerprint as usize
}
/// Detects a period-2 oscillation in the rolling hash history.
///
/// The board alternates between two states (A, B, A, B, ...) exactly when
/// `hashes[0] == hashes[2]` AND `hashes[1] == hashes[3]`. After the check,
/// the oldest hash is dropped so the four-entry window slides by one per
/// call.
///
/// Returns `false` (and removes nothing) until at least four hashes have
/// been recorded.
pub fn detect_repeating_state(board_hashes: &mut Vec<usize>) -> bool {
    // Two full periods (four samples) are needed to compare 0==2 and 1==3.
    if board_hashes.len() < 4 {
        return false;
    }
    // BUG FIX: the second clause previously re-checked `hashes[0] == hashes[2]`
    // (a copy-paste of the first), so any match two frames apart was
    // misreported as an oscillation. The comment above the original already
    // stated the intended check: "if 0==2, 1==3".
    let repeating = board_hashes[0] == board_hashes[2] && board_hashes[1] == board_hashes[3];
    // Slide the window: drop the oldest hash.
    board_hashes.remove(0);
    repeating
}
/// Injects fresh activity: 20 random gliders and 10 random pulsars.
///
/// Called at startup and whenever the simulation stalls (stable board,
/// repeating board, or extinction).
pub fn add_simulated_activity(current_board: &mut BoolMatrix, board_size: usize) {
    const GLIDER_COUNT: usize = 20;
    const PULSAR_COUNT: usize = 10;
    for _ in 0..GLIDER_COUNT {
        generate_glider(current_board, board_size);
    }
    for _ in 0..PULSAR_COUNT {
        generate_pulsar(current_board, board_size);
    }
}
// generate prime numbers
/// Returns the first `n` prime numbers (e.g. `n = 5` -> `[2, 3, 5, 7, 11]`).
///
/// Uses trial division, but only against the primes found so far and only
/// up to `sqrt(candidate)` — fewer divisions than testing every integer,
/// and no float round-trip for the square root. `n <= 0` yields an empty
/// vector, matching the original behavior.
pub fn generate_primes(n: i32) -> Vec<i32> {
    // Negative counts are treated as zero.
    let target = n.max(0) as usize;
    let mut primes: Vec<i32> = Vec::with_capacity(target);
    let mut candidate: i32 = 2;
    while primes.len() < target {
        // A composite number has a prime factor <= sqrt(candidate), so trial
        // division against previously found primes, capped at that bound, is
        // sufficient. `saturating_mul` guards the p*p comparison against
        // overflow for large primes.
        let is_prime = primes
            .iter()
            .take_while(|&&p| p.saturating_mul(p) <= candidate)
            .all(|&p| candidate % p != 0);
        if is_prime {
            primes.push(candidate);
        }
        candidate += 1;
    }
    primes
}
// --- Tests from previous example (can be kept or adapted) ---
#[cfg(test)]
mod tests {
    use super::*;
    use rustframe::matrix::{BoolMatrix, BoolOps}; // Assuming BoolOps is available for .count()

    /// A blinker (three cells in a line) oscillates with period 2: one
    /// step rotates it, a second step restores the original data.
    #[test]
    fn test_blinker_oscillator() {
        let initial_data = vec![false, true, false, false, true, false, false, true, false];
        let game1 = BoolMatrix::from_vec(initial_data.clone(), 3, 3);
        // Expected layout after one step.
        let expected_frame2_data = vec![false, false, false, true, true, true, false, false, false];
        let expected_game2 = BoolMatrix::from_vec(expected_frame2_data, 3, 3);
        let game2 = game_of_life_next_frame(&game1);
        assert_eq!(
            game2.data(),
            expected_game2.data(),
            "Frame 1 to Frame 2 failed for blinker"
        );
        // A second step must restore the initial state (period 2).
        let expected_game3 = BoolMatrix::from_vec(initial_data, 3, 3);
        let game3 = game_of_life_next_frame(&game2);
        assert_eq!(
            game3.data(),
            expected_game3.data(),
            "Frame 2 to Frame 3 failed for blinker"
        );
    }

    /// A board with no live cells must stay empty after one step.
    #[test]
    fn test_empty_board_remains_empty() {
        let board_3x3_all_false = BoolMatrix::from_vec(vec![false; 9], 3, 3);
        let next_frame = game_of_life_next_frame(&board_3x3_all_false);
        assert_eq!(
            next_frame.count(),
            0,
            "All-false board should result in all-false"
        );
    }

    /// The degenerate 0x0 board is passed through unchanged.
    #[test]
    fn test_zero_size_board() {
        let board_0x0 = BoolMatrix::from_vec(vec![], 0, 0);
        let next_frame = game_of_life_next_frame(&board_0x0);
        assert_eq!(next_frame.rows(), 0);
        assert_eq!(next_frame.cols(), 0);
        assert!(
            next_frame.data().is_empty(),
            "0x0 board should result in 0x0 board"
        );
    }

    /// A 2x2 block is a still life: it must survive one step untouched.
    #[test]
    fn test_still_life_block() {
        let block_data = vec![
            true, true, false, false, true, true, false, false, false, false, false, false, false,
            false, false, false,
        ];
        let game_block = BoolMatrix::from_vec(block_data.clone(), 4, 4);
        let next_frame_block = game_of_life_next_frame(&game_block);
        assert_eq!(
            next_frame_block.data(),
            game_block.data(),
            "Block still life should remain unchanged"
        );
    }
}

View File

@@ -1,3 +0,0 @@
/// Model implementations (activation functions, dense neural networks).
pub mod models;
/// Statistics routines (contents defined in the `stats` module).
pub mod stats;

View File

@@ -1,135 +0,0 @@
use crate::matrix::{Matrix, SeriesOps};
/// Element-wise logistic sigmoid: `1 / (1 + e^(-x))` applied to every entry.
pub fn sigmoid(x: &Matrix<f64>) -> Matrix<f64> {
    x.map(|v| 1.0 / (1.0 + (-v).exp()))
}
/// Sigmoid derivative expressed in terms of the sigmoid *output*:
/// given `y = sigmoid(x)`, returns `y * (1 - y)` element-wise.
pub fn dsigmoid(y: &Matrix<f64>) -> Matrix<f64> {
    // derivative w.r.t. pre-activation; takes y = sigmoid(x)
    y.map(|v| v * (1.0 - v))
}
/// Element-wise ReLU: `max(0, x)`.
pub fn relu(x: &Matrix<f64>) -> Matrix<f64> {
    x.map(|v| if v > 0.0 { v } else { 0.0 })
}
/// Element-wise ReLU derivative: `1` for positive inputs, else `0`
/// (the subgradient at exactly 0 is taken to be 0).
pub fn drelu(x: &Matrix<f64>) -> Matrix<f64> {
    x.map(|v| if v > 0.0 { 1.0 } else { 0.0 })
}
/// Element-wise leaky ReLU with slope 0.01 on the negative side:
/// `x` for positive inputs, `0.01 * x` otherwise.
pub fn leaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
    x.map(|v| if v > 0.0 { v } else { 0.01 * v })
}
/// Element-wise leaky ReLU derivative: `1` for positive inputs, else `0.01`.
pub fn dleaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
    x.map(|v| if v > 0.0 { 1.0 } else { 0.01 })
}
// Unit tests for the activation functions.
//
// FIX: this module previously lacked `#[cfg(test)]`, so `use super::*` and
// the test helpers were compiled into non-test builds (dead-code warnings
// and needless code in release binaries). It is now test-only, matching the
// convention used by the other test modules in this crate.
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: round all elements in a matrix to `decimals` places so
    // float results can be compared with `assert_eq!`.
    fn _round_matrix(mat: &Matrix<f64>, decimals: u32) -> Matrix<f64> {
        let factor = 10f64.powi(decimals as i32);
        let rounded: Vec<f64> = mat
            .to_vec()
            .iter()
            .map(|v| (v * factor).round() / factor)
            .collect();
        Matrix::from_vec(rounded, mat.rows(), mat.cols())
    }

    #[test]
    fn test_sigmoid() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
        let result = sigmoid(&x);
        assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
    }

    #[test]
    fn test_sigmoid_edge_case() {
        // Extreme inputs must saturate to 0 and 1 without NaN/overflow.
        let x = Matrix::from_vec(vec![-1000.0, 0.0, 1000.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.5, 1.0], 3, 1);
        let result = sigmoid(&x);
        for (r, e) in result.data().iter().zip(expected.data().iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_relu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        assert_eq!(relu(&x), expected);
    }

    #[test]
    fn test_relu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1e10], 3, 1);
        assert_eq!(relu(&x), expected);
    }

    #[test]
    fn test_dsigmoid() {
        let y = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
        let expected = Matrix::from_vec(vec![0.19661193, 0.25, 0.19661193], 3, 1);
        let result = dsigmoid(&y);
        assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
    }

    #[test]
    fn test_dsigmoid_edge_case() {
        let y = Matrix::from_vec(vec![0.0, 0.5, 1.0], 3, 1); // Assume these are outputs from sigmoid(x)
        let expected = Matrix::from_vec(vec![0.0, 0.25, 0.0], 3, 1);
        let result = dsigmoid(&y);
        for (r, e) in result.data().iter().zip(expected.data().iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_drelu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        assert_eq!(drelu(&x), expected);
    }

    #[test]
    fn test_drelu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        assert_eq!(drelu(&x), expected);
    }

    #[test]
    fn test_leaky_relu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![-0.01, 0.0, 1.0], 3, 1);
        assert_eq!(leaky_relu(&x), expected);
    }

    #[test]
    fn test_leaky_relu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![-1e-12, 0.0, 1e10], 3, 1);
        assert_eq!(leaky_relu(&x), expected);
    }

    #[test]
    fn test_dleaky_relu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
        assert_eq!(dleaky_relu(&x), expected);
    }

    #[test]
    fn test_dleaky_relu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
        assert_eq!(dleaky_relu(&x), expected);
    }
}

View File

@@ -1,524 +0,0 @@
use crate::compute::models::activations::{drelu, relu, sigmoid};
use crate::matrix::{Matrix, SeriesOps};
use rand::prelude::*;
/// Supported activation functions
#[derive(Clone)]
pub enum ActivationKind {
    /// Rectified linear unit: `max(0, z)`.
    Relu,
    /// Logistic sigmoid: `1 / (1 + e^(-z))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
}
impl ActivationKind {
    /// Applies the activation function element-wise to `z`.
    pub fn forward(&self, z: &Matrix<f64>) -> Matrix<f64> {
        match self {
            Self::Relu => relu(z),
            Self::Sigmoid => sigmoid(z),
            Self::Tanh => z.map(|v| v.tanh()),
        }
    }

    /// Element-wise derivative with respect to the pre-activation `z`.
    pub fn derivative(&self, z: &Matrix<f64>) -> Matrix<f64> {
        match self {
            Self::Relu => drelu(z),
            // sigmoid'(z) = s * (1 - s), where s = sigmoid(z).
            Self::Sigmoid => sigmoid(z).map(|s| s * (1.0 - s)),
            // tanh'(z) = 1 - tanh(z)^2.
            Self::Tanh => z.map(|v| 1.0 - v.tanh().powi(2)),
        }
    }
}
/// Weight initialization schemes
#[derive(Clone)]
pub enum InitializerKind {
    /// Uniform(-limit .. limit)
    Uniform(f64),
    /// Xavier/Glorot uniform: limit = sqrt(6 / (fan_in + fan_out)).
    Xavier,
    /// He (Kaiming) uniform: limit = sqrt(2 / fan_in).
    He,
}
impl InitializerKind {
    /// Samples a `rows x cols` weight matrix from `Uniform(-limit, limit)`,
    /// where `limit` depends on the scheme (with `fan_in = rows`,
    /// `fan_out = cols`).
    ///
    /// ROBUSTNESS FIX: a non-positive limit (e.g. `Uniform(0.0)`) previously
    /// panicked inside `random_range` because the sample range was empty;
    /// it now returns an all-zero matrix instead.
    pub fn initialize(&self, rows: usize, cols: usize) -> Matrix<f64> {
        let fan_in = rows;
        let fan_out = cols;
        let limit = match self {
            InitializerKind::Uniform(l) => *l,
            InitializerKind::Xavier => (6.0 / (fan_in + fan_out) as f64).sqrt(),
            InitializerKind::He => (2.0 / fan_in as f64).sqrt(),
        };
        // Guard: `random_range(-limit..limit)` panics on an empty range.
        if limit <= 0.0 {
            return Matrix::zeros(rows, cols);
        }
        let mut rng = rand::rng();
        let data = (0..rows * cols)
            .map(|_| rng.random_range(-limit..limit))
            .collect::<Vec<_>>();
        Matrix::from_vec(data, rows, cols)
    }
}
/// Supported losses
#[derive(Clone)]
pub enum LossKind {
    /// Mean Squared Error: L = 1/m * sum((y_hat - y)^2)
    MSE,
    /// Binary Cross-Entropy: L = -1/m * sum(y*log(y_hat) + (1-y)*log(1-y_hat))
    BCE,
}
impl LossKind {
    /// Compute gradient dL/dy_hat (before applying activation derivative)
    ///
    /// `m` is the batch size (number of rows in `y`).
    ///
    /// NOTE(review): for `BCE` the factor `(y_hat - y) / m` is the gradient
    /// w.r.t. the *pre-activation* of a sigmoid output layer (the sigmoid
    /// term cancels); `DenseNN::train` relies on this by skipping the
    /// activation derivative in the BCE branch.
    pub fn gradient(&self, y_hat: &Matrix<f64>, y: &Matrix<f64>) -> Matrix<f64> {
        let m = y.rows() as f64;
        match self {
            // d/dŷ [ 1/m * Σ (ŷ - y)² ] = 2 (ŷ - y) / m
            LossKind::MSE => (y_hat - y) * (2.0 / m),
            LossKind::BCE => (y_hat - y) * (1.0 / m),
        }
    }
}
/// Configuration for a dense neural network
pub struct DenseNNConfig {
    /// Number of input features.
    pub input_size: usize,
    /// Width of each hidden layer, in order. May be empty.
    pub hidden_layers: Vec<usize>,
    /// Must have length = hidden_layers.len() + 1
    /// (one activation per weight layer, including the output layer).
    pub activations: Vec<ActivationKind>,
    /// Number of output units.
    pub output_size: usize,
    /// Weight initialization scheme applied to every layer.
    pub initializer: InitializerKind,
    /// Loss driving the output-layer gradient during training.
    pub loss: LossKind,
    /// Gradient-descent step size.
    pub learning_rate: f64,
    /// Number of full passes over the training data in `train`.
    pub epochs: usize,
}
/// A multi-layer perceptron with full configurability
pub struct DenseNN {
    // One weight matrix per layer; weights[l] is (sizes[l], sizes[l+1]).
    weights: Vec<Matrix<f64>>,
    // One 1-row bias matrix per layer, repeated over the batch in forward passes.
    biases: Vec<Matrix<f64>>,
    // Activation applied after each layer; same length as `weights`.
    activations: Vec<ActivationKind>,
    // Loss used to compute the output-layer gradient in `train`.
    loss: LossKind,
    // Learning rate for gradient-descent updates.
    lr: f64,
    // Number of training epochs.
    epochs: usize,
}
impl DenseNN {
    /// Build a new DenseNN from the given configuration
    ///
    /// # Panics
    /// Panics if `config.activations.len() != hidden_layers.len() + 1`.
    pub fn new(config: DenseNNConfig) -> Self {
        // Layer sizes: input, hidden..., output.
        let mut sizes = vec![config.input_size];
        sizes.extend(&config.hidden_layers);
        sizes.push(config.output_size);
        assert_eq!(
            config.activations.len(),
            sizes.len() - 1,
            "Number of activation functions must match number of layers"
        );
        let mut weights = Vec::with_capacity(sizes.len() - 1);
        let mut biases = Vec::with_capacity(sizes.len() - 1);
        for i in 0..sizes.len() - 1 {
            // Weights are sampled by the configured initializer; biases start at zero.
            let w = config.initializer.initialize(sizes[i], sizes[i + 1]);
            let b = Matrix::zeros(1, sizes[i + 1]);
            weights.push(w);
            biases.push(b);
        }
        DenseNN {
            weights,
            biases,
            activations: config.activations,
            loss: config.loss,
            lr: config.learning_rate,
            epochs: config.epochs,
        }
    }

    /// Perform a full forward pass, returning pre-activations (z) and activations (a)
    ///
    /// `zs[l]` is the pre-activation of layer `l`; `activs[0]` is the input
    /// itself, so `activs` has one more entry than `zs`.
    fn forward_full(&self, x: &Matrix<f64>) -> (Vec<Matrix<f64>>, Vec<Matrix<f64>>) {
        let mut zs = Vec::with_capacity(self.weights.len());
        let mut activs = Vec::with_capacity(self.weights.len() + 1);
        activs.push(x.clone());
        let mut a = x.clone();
        for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            // z = a·W + b, with the bias row repeated across the batch.
            let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
            let a_next = self.activations[i].forward(&z);
            zs.push(z);
            activs.push(a_next.clone());
            a = a_next;
        }
        (zs, activs)
    }

    /// Train the network on inputs X and targets Y
    ///
    /// Runs `epochs` full-batch gradient-descent passes, backpropagating the
    /// loss gradient and updating every layer's weights and biases in place.
    pub fn train(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
        let m = x.rows() as f64;
        for _ in 0..self.epochs {
            let (zs, activs) = self.forward_full(x);
            let y_hat = activs.last().unwrap().clone();

            // Initial delta (dL/dz) on output
            let mut delta = match self.loss {
                // BCE: the loss gradient is already w.r.t. the pre-activation,
                // so no activation derivative is applied here (see LossKind).
                LossKind::BCE => self.loss.gradient(&y_hat, y),
                // MSE: chain rule — multiply by the output activation's derivative.
                LossKind::MSE => {
                    let grad = self.loss.gradient(&y_hat, y);
                    let dz = self
                        .activations
                        .last()
                        .unwrap()
                        .derivative(zs.last().unwrap());
                    grad.zip(&dz, |g, da| g * da)
                }
            };

            // Backpropagate through layers
            for l in (0..self.weights.len()).rev() {
                let a_prev = &activs[l];
                // Batch-averaged gradients for this layer's weights and biases.
                let dw = a_prev.transpose().dot(&delta) / m;
                let db = Matrix::from_vec(delta.sum_vertical(), 1, delta.cols()) / m;
                // Update weights & biases
                self.weights[l] = &self.weights[l] - &(dw * self.lr);
                self.biases[l] = &self.biases[l] - &(db * self.lr);
                // Propagate delta to previous layer
                // (uses the *updated* weights, as in the original implementation).
                if l > 0 {
                    let w_t = self.weights[l].transpose();
                    let da = self.activations[l - 1].derivative(&zs[l - 1]);
                    delta = delta.dot(&w_t).zip(&da, |d, a| d * a);
                }
            }
        }
    }

    /// Run a forward pass and return the network's output
    pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
        let mut a = x.clone();
        for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
            a = self.activations[i].forward(&z);
        }
        a
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;

    /// Compute MSE = 1/m * Σ (ŷ - y)²
    fn mse_loss(y_hat: &Matrix<f64>, y: &Matrix<f64>) -> f64 {
        let m = y.rows() as f64;
        y_hat
            .zip(y, |yh, yv| (yh - yv).powi(2))
            .data()
            .iter()
            .sum::<f64>()
            / m
    }

    #[test]
    fn test_predict_shape() {
        let config = DenseNNConfig {
            input_size: 1,
            hidden_layers: vec![2],
            activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::MSE,
            learning_rate: 0.01,
            epochs: 0,
        };
        let model = DenseNN::new(config);
        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1);
        let preds = model.predict(&x);
        assert_eq!(preds.rows(), 3);
        assert_eq!(preds.cols(), 1);
    }

    #[test]
    #[should_panic(expected = "Number of activation functions must match number of layers")]
    fn test_invalid_activation_count() {
        let config = DenseNNConfig {
            input_size: 2,
            hidden_layers: vec![3],
            activations: vec![ActivationKind::Relu], // Only one activation for two layers
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::MSE,
            learning_rate: 0.01,
            epochs: 0,
        };
        let _model = DenseNN::new(config);
    }

    #[test]
    fn test_train_no_epochs_does_nothing() {
        let config = DenseNNConfig {
            input_size: 1,
            hidden_layers: vec![2],
            activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::MSE,
            learning_rate: 0.01,
            epochs: 0,
        };
        let mut model = DenseNN::new(config);
        let x = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
        let y = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
        let before = model.predict(&x);
        model.train(&x, &y);
        let after = model.predict(&x);
        for i in 0..before.rows() {
            for j in 0..before.cols() {
                assert!(
                    (before[(i, j)] - after[(i, j)]).abs() < 1e-12,
                    "prediction changed despite 0 epochs"
                );
            }
        }
    }

    #[test]
    fn test_train_one_epoch_changes_predictions() {
        // Single-layer sigmoid regression so gradients flow.
        let config = DenseNNConfig {
            input_size: 1,
            hidden_layers: vec![],
            activations: vec![ActivationKind::Sigmoid],
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::MSE,
            learning_rate: 1.0,
            epochs: 1,
        };
        let mut model = DenseNN::new(config);
        let x = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
        let y = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
        let before = model.predict(&x);
        model.train(&x, &y);
        let after = model.predict(&x);
        // At least one of the two outputs must move by >ϵ
        let mut moved = false;
        for i in 0..before.rows() {
            if (before[(i, 0)] - after[(i, 0)]).abs() > 1e-8 {
                moved = true;
            }
        }
        assert!(moved, "predictions did not change after 1 epoch");
    }

    #[test]
    fn test_training_reduces_mse_loss() {
        // Same single-layer sigmoid setup; check loss goes down.
        let config = DenseNNConfig {
            input_size: 1,
            hidden_layers: vec![],
            activations: vec![ActivationKind::Sigmoid],
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::MSE,
            learning_rate: 1.0,
            epochs: 10,
        };
        let mut model = DenseNN::new(config);
        let x = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
        let y = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
        let before_preds = model.predict(&x);
        let before_loss = mse_loss(&before_preds, &y);
        model.train(&x, &y);
        let after_preds = model.predict(&x);
        let after_loss = mse_loss(&after_preds, &y);
        assert!(
            after_loss < before_loss,
            "MSE did not decrease (before: {}, after: {})",
            before_loss,
            after_loss
        );
    }

    #[test]
    fn test_activation_kind_forward_tanh() {
        let input = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![-0.76159415595, 0.0, 0.76159415595], 3, 1);
        let output = ActivationKind::Tanh.forward(&input);
        for i in 0..input.rows() {
            for j in 0..input.cols() {
                assert!(
                    (output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
                    "Tanh forward output mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_activation_kind_derivative_relu() {
        let input = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        let output = ActivationKind::Relu.derivative(&input);
        for i in 0..input.rows() {
            for j in 0..input.cols() {
                assert!(
                    (output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
                    "ReLU derivative output mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_activation_kind_derivative_tanh() {
        let input = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.41997434161, 1.0, 0.41997434161], 3, 1); // 1 - tanh(x)^2
        let output = ActivationKind::Tanh.derivative(&input);
        for i in 0..input.rows() {
            for j in 0..input.cols() {
                assert!(
                    (output[(i, j)] - expected[(i, j)]).abs() < 1e-9,
                    "Tanh derivative output mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_initializer_kind_xavier() {
        let rows = 10;
        let cols = 20;
        let initializer = InitializerKind::Xavier;
        let matrix = initializer.initialize(rows, cols);
        let limit = (6.0 / (rows + cols) as f64).sqrt();
        assert_eq!(matrix.rows(), rows);
        assert_eq!(matrix.cols(), cols);
        for val in matrix.data() {
            assert!(
                *val >= -limit && *val <= limit,
                "Xavier initialized value out of range"
            );
        }
    }

    #[test]
    fn test_initializer_kind_he() {
        let rows = 10;
        let cols = 20;
        let initializer = InitializerKind::He;
        let matrix = initializer.initialize(rows, cols);
        let limit = (2.0 / rows as f64).sqrt();
        assert_eq!(matrix.rows(), rows);
        assert_eq!(matrix.cols(), cols);
        for val in matrix.data() {
            assert!(
                *val >= -limit && *val <= limit,
                "He initialized value out of range"
            );
        }
    }

    #[test]
    fn test_loss_kind_bce_gradient() {
        let y_hat = Matrix::from_vec(vec![0.1, 0.9, 0.4], 3, 1);
        let y = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
        let expected_gradient = Matrix::from_vec(vec![0.1 / 3.0, -0.1 / 3.0, -0.1 / 3.0], 3, 1); // (y_hat - y) * (1.0 / m)
        let output_gradient = LossKind::BCE.gradient(&y_hat, &y);
        assert_eq!(output_gradient.rows(), expected_gradient.rows());
        assert_eq!(output_gradient.cols(), expected_gradient.cols());
        for i in 0..output_gradient.rows() {
            for j in 0..output_gradient.cols() {
                assert!(
                    (output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9,
                    "BCE gradient output mismatch at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_training_reduces_bce_loss() {
        // Single-layer sigmoid setup; check BCE loss goes down.
        let config = DenseNNConfig {
            input_size: 1,
            hidden_layers: vec![],
            activations: vec![ActivationKind::Sigmoid],
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::BCE,
            learning_rate: 1.0,
            epochs: 10,
        };
        let mut model = DenseNN::new(config);
        let x = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
        let y = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
        let before_preds = model.predict(&x);
        // BCE loss calculation for testing
        let before_loss = -1.0 / (y.rows() as f64)
            * before_preds
                .zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
                .data()
                .iter()
                .sum::<f64>();
        model.train(&x, &y);
        let after_preds = model.predict(&x);
        let after_loss = -1.0 / (y.rows() as f64)
            * after_preds
                .zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
                .data()
                .iter()
                .sum::<f64>();
        assert!(
            after_loss < before_loss,
            "BCE did not decrease (before: {}, after: {})",
            before_loss,
            after_loss
        );
    }

    #[test]
    fn test_train_backprop_delta_propagation() {
        // Network with two layers to test delta propagation to previous layer (l > 0)
        let config = DenseNNConfig {
            input_size: 2,
            hidden_layers: vec![3],
            activations: vec![ActivationKind::Sigmoid, ActivationKind::Sigmoid],
            output_size: 1,
            initializer: InitializerKind::Uniform(0.1),
            loss: LossKind::MSE,
            learning_rate: 0.1,
            epochs: 1,
        };
        let mut model = DenseNN::new(config);
        // Store initial weights and biases to compare after training
        let initial_weights_l0 = model.weights[0].clone();
        let initial_biases_l0 = model.biases[0].clone();
        let initial_weights_l1 = model.weights[1].clone();
        let initial_biases_l1 = model.biases[1].clone();
        let x = Matrix::from_vec(vec![0.1, 0.2, 0.3, 0.4], 2, 2);
        let y = Matrix::from_vec(vec![0.5, 0.6], 2, 1);
        model.train(&x, &y);
        // Verify that weights and biases of both layers have changed,
        // implying delta propagation occurred for l > 0
        assert!(
            model.weights[0] != initial_weights_l0,
            "Weights of first layer did not change, delta propagation might not have occurred"
        );
        assert!(
            model.biases[0] != initial_biases_l0,
            "Biases of first layer did not change, delta propagation might not have occurred"
        );
        assert!(
            model.weights[1] != initial_weights_l1,
            "Weights of second layer did not change"
        );
        assert!(
            model.biases[1] != initial_biases_l1,
            "Biases of second layer did not change"
        );
    }
}

View File

@@ -1,230 +0,0 @@
use crate::matrix::Matrix;
use std::collections::HashMap;
/// A Gaussian Naive Bayes classifier.
///
/// All per-class statistics are populated by [`GaussianNB::fit`]; a freshly
/// constructed model has empty vectors.
///
/// # Parameters
/// - `var_smoothing`: Portion of the largest variance of all features to add to variances for stability.
/// - `use_unbiased_variance`: If `true`, uses Bessel's correction (dividing by (n-1)); otherwise divides by n.
///
pub struct GaussianNB {
// Distinct class labels, sorted ascending after `fit`
classes: Vec<f64>,
// Prior probabilities P(class), parallel to `classes`
priors: Vec<f64>,
// Feature means per class, each a (1, n_features) row
means: Vec<Matrix<f64>>,
// Feature variances per class, each a (1, n_features) row
variances: Vec<Matrix<f64>>,
// var_smoothing factor (scaled by the largest feature variance in `fit`)
eps: f64,
// flag for unbiased variance (Bessel's correction)
use_unbiased: bool,
}
impl GaussianNB {
/// Create a new GaussianNB.
///
/// # Arguments
/// * `var_smoothing` - small float added to variances for numerical stability.
/// * `use_unbiased_variance` - whether to apply Bessel's correction (divide by n-1).
pub fn new(var_smoothing: f64, use_unbiased_variance: bool) -> Self {
Self {
classes: Vec::new(),
priors: Vec::new(),
means: Vec::new(),
variances: Vec::new(),
eps: var_smoothing,
use_unbiased: use_unbiased_variance,
}
}
/// Fit the model according to the training data `x` and labels `y`.
///
/// `x` is (n_samples, n_features); `y` is an (n_samples, 1) column whose
/// entries are the class labels. Refitting replaces any previous statistics.
///
/// # Panics
/// Panics if `x` or `y` is empty, or if their dimensions disagree.
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
let m = x.rows();
let n = x.cols();
assert_eq!(y.rows(), m, "Row count of X and Y must match");
assert_eq!(y.cols(), 1, "Y must be a column vector");
if m == 0 || n == 0 {
panic!("Input matrix x or y is empty");
}
// Group sample indices by label.
// f64 cannot be a HashMap key, so labels are keyed by their raw bit
// pattern (`to_bits`); distinct bit patterns mean distinct classes.
let mut groups: HashMap<u64, Vec<usize>> = HashMap::new();
for i in 0..m {
let label = y[(i, 0)];
let bits = label.to_bits();
groups.entry(bits).or_default().push(i);
}
assert!(!groups.is_empty(), "No class labels found in y"); //-- panicked earlier
// Extract and sort class labels (ascending)
self.classes = groups.keys().cloned().map(f64::from_bits).collect();
self.classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
// Refit from scratch: drop statistics of any previous fit
self.priors.clear();
self.means.clear();
self.variances.clear();
// Precompute max variance across all features (population variance over
// the whole dataset) to scale the smoothing term, as in scikit-learn
let mut max_var_feature = 0.0;
for j in 0..n {
let mut col_vals = Vec::with_capacity(m);
for i in 0..m {
col_vals.push(x[(i, j)]);
}
let mean_all = col_vals.iter().sum::<f64>() / m as f64;
let var_all = col_vals.iter().map(|v| (v - mean_all).powi(2)).sum::<f64>() / m as f64;
if var_all > max_var_feature {
max_var_feature = var_all;
}
}
let smoothing = self.eps * max_var_feature;
// Compute per-class statistics (prior, mean, variance)
for &c in &self.classes {
let idx = &groups[&c.to_bits()];
let count = idx.len();
// Prior: fraction of samples belonging to this class
self.priors.push(count as f64 / m as f64);
let mut mean = Matrix::zeros(1, n);
let mut var = Matrix::zeros(1, n);
// Mean: accumulate then divide by class count
for &i in idx {
for j in 0..n {
mean[(0, j)] += x[(i, j)];
}
}
for j in 0..n {
mean[(0, j)] /= count as f64;
}
// Variance: sum of squared deviations from the class mean
for &i in idx {
for j in 0..n {
let d = x[(i, j)] - mean[(0, j)];
var[(0, j)] += d * d;
}
}
// Divisor: n-1 (clamped to 1 for singleton classes) when unbiased, else n
let denom = if self.use_unbiased {
(count as f64 - 1.0).max(1.0)
} else {
count as f64
};
for j in 0..n {
var[(0, j)] = var[(0, j)] / denom + smoothing;
// Guard against non-positive variance; note that with
// var_smoothing == 0 this can still leave the variance at 0
if var[(0, j)] <= 0.0 {
var[(0, j)] = smoothing;
}
}
self.means.push(mean);
self.variances.push(var);
}
}
/// Perform classification on an array of test vectors `x`.
///
/// Returns an (n_samples, 1) column of predicted class labels; each sample
/// gets the class maximizing the Gaussian log-likelihood plus log-prior.
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
let m = x.rows();
let n = x.cols();
let k = self.classes.len();
let mut preds = Matrix::zeros(m, 1);
// ln(2π), hoisted out of the per-feature log-density term
let ln_2pi = (2.0 * std::f64::consts::PI).ln();
for i in 0..m {
// (best class index, best log-posterior so far)
let mut best = (0, f64::NEG_INFINITY);
for c_idx in 0..k {
let mut log_prob = self.priors[c_idx].ln();
for j in 0..n {
let diff = x[(i, j)] - self.means[c_idx][(0, j)];
let var = self.variances[c_idx][(0, j)];
// Log of the univariate normal density (up to exact constants)
log_prob += -0.5 * (diff * diff / var + var.ln() + ln_2pi);
}
if log_prob > best.1 {
best = (c_idx, log_prob);
}
}
preds[(i, 0)] = self.classes[best.0];
}
preds
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::matrix::Matrix;
#[test]
fn test_simple_two_class() {
// Simple dataset: one feature, two classes 0 and 1
// Class 0: values [1.0, 1.2, 0.8]
// Class 1: values [3.0, 3.2, 2.8]
let x = Matrix::from_vec(vec![1.0, 1.2, 0.8, 3.0, 3.2, 2.8], 6, 1);
let y = Matrix::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0], 6, 1);
let mut clf = GaussianNB::new(1e-9, false);
clf.fit(&x, &y);
// Query one point near each class centre
let test = Matrix::from_vec(vec![1.1, 3.1], 2, 1);
let preds = clf.predict(&test);
assert_eq!(preds[(0, 0)], 0.0);
assert_eq!(preds[(1, 0)], 1.0);
}
#[test]
fn test_unbiased_variance() {
// Same as above but with unbiased (Bessel-corrected) variance
let x = Matrix::from_vec(vec![2.0, 2.2, 1.8, 4.0, 4.2, 3.8], 6, 1);
let y = Matrix::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0], 6, 1);
let mut clf = GaussianNB::new(1e-9, true);
clf.fit(&x, &y);
let test = Matrix::from_vec(vec![2.1, 4.1], 2, 1);
let preds = clf.predict(&test);
assert_eq!(preds[(0, 0)], 0.0);
assert_eq!(preds[(1, 0)], 1.0);
}
#[test]
#[should_panic]
fn test_empty_input() {
// fit must panic on an empty design matrix
let x = Matrix::zeros(0, 0);
let y = Matrix::zeros(0, 1);
let mut clf = GaussianNB::new(1e-9, false);
clf.fit(&x, &y);
}
#[test]
#[should_panic = "Row count of X and Y must match"]
fn test_mismatched_rows() {
// fit must panic when x and y disagree on the sample count
let x = Matrix::from_vec(vec![1.0, 2.0], 2, 1);
let y = Matrix::from_vec(vec![0.0], 1, 1);
let mut clf = GaussianNB::new(1e-9, false);
clf.fit(&x, &y);
}
#[test]
fn test_variance_smoothing_override_with_zero_smoothing() {
// Scenario: var_smoothing is 0, and a feature has zero variance within a class.
// This should trigger the `if var[(0, j)] <= 0.0 { var[(0, j)] = smoothing; }` line.
let x = Matrix::from_vec(vec![1.0, 1.0, 2.0], 3, 1); // Class 0: [1.0, 1.0], Class 1: [2.0]
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
let mut clf = GaussianNB::new(0.0, false); // var_smoothing = 0.0
clf.fit(&x, &y);
// For class 0 (index 0 in clf.classes), the feature (index 0) had values [1.0, 1.0], so variance was 0.
// Since var_smoothing was 0, smoothing is 0.
// The line `var[(0, j)] = smoothing;` should have set the variance to 0.0.
let class_0_idx = clf.classes.iter().position(|&c| c == 0.0).unwrap();
assert_eq!(clf.variances[class_0_idx][(0, 0)], 0.0);
// For class 1 (index 1 in clf.classes), the feature (index 0) had value [2.0].
// Variance calculation for a single point results in 0.
// The if condition will be true, and var[(0, j)] will be set to smoothing (0.0).
let class_1_idx = clf.classes.iter().position(|&c| c == 1.0).unwrap();
assert_eq!(clf.variances[class_1_idx][(0, 0)], 0.0);
}
}

View File

@@ -1,364 +0,0 @@
use crate::compute::stats::mean_vertical;
use crate::matrix::Matrix;
use rand::rng;
use rand::seq::SliceRandom;
/// K-Means clustering model: holds the centroids learned by [`KMeans::fit`].
pub struct KMeans {
pub centroids: Matrix<f64>, // (k, n_features)
}
impl KMeans {
/// Fit with k clusters.
///
/// Runs Lloyd's algorithm for at most `max_iter` iterations, stopping early
/// when assignments are stable or the total centroid shift falls below
/// `tol` (when `tol > 0`). Returns the fitted model and the final cluster
/// label of every sample.
///
/// # Panics
/// Panics if `k` exceeds the number of samples.
pub fn fit(x: &Matrix<f64>, k: usize, max_iter: usize, tol: f64) -> (Self, Vec<usize>) {
let m = x.rows();
let n = x.cols();
assert!(k <= m, "k must be ≤ number of samples");
// ----- initialise centroids -----
let mut centroids = Matrix::zeros(k, n);
if k > 0 && m > 0 {
// case for empty data
if k == 1 {
// Single cluster: the centroid is just the column-wise mean
let mean = mean_vertical(x);
centroids.row_copy_from_slice(0, &mean.data()); // ideally, data.row(0), but thats the same
} else {
// For k > 1, pick k distinct rows at random
let mut rng = rng();
let mut indices: Vec<usize> = (0..m).collect();
indices.shuffle(&mut rng);
for c in 0..k {
centroids.row_copy_from_slice(c, &x.row(indices[c]));
}
}
}
let mut labels = vec![0usize; m];
// Squared distance of each sample to its assigned centroid (used to
// re-seed empty clusters below)
let mut distances = vec![0.0f64; m];
for _iter in 0..max_iter {
let mut changed = false;
// ----- assignment step: nearest centroid by squared Euclidean distance -----
for i in 0..m {
let sample_row = x.row(i);
let mut best = 0usize;
let mut best_dist_sq = f64::MAX;
for c in 0..k {
let centroid_row = centroids.row(c);
let dist_sq: f64 = sample_row
.iter()
.zip(centroid_row.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
if dist_sq < best_dist_sq {
best_dist_sq = dist_sq;
best = c;
}
}
distances[i] = best_dist_sq;
if labels[i] != best {
labels[i] = best;
changed = true;
}
}
// ----- update step: recompute each centroid as the mean of its members -----
let mut new_centroids = Matrix::zeros(k, n);
let mut counts = vec![0usize; k];
for i in 0..m {
let c = labels[i];
counts[c] += 1;
for j in 0..n {
new_centroids[(c, j)] += x[(i, j)];
}
}
for c in 0..k {
if counts[c] == 0 {
// This cluster is empty. Re-initialize its centroid to the point
// furthest from its assigned centroid to prevent the cluster from dying.
let mut furthest_point_idx = 0;
let mut max_dist_sq = 0.0;
for (i, &dist) in distances.iter().enumerate() {
if dist > max_dist_sq {
max_dist_sq = dist;
furthest_point_idx = i;
}
}
for j in 0..n {
new_centroids[(c, j)] = x[(furthest_point_idx, j)];
}
// Ensure this point isn't chosen again for another empty cluster in the same iteration.
if m > 0 {
distances[furthest_point_idx] = 0.0;
}
} else {
// Normalize the centroid by the number of points in it.
for j in 0..n {
new_centroids[(c, j)] /= counts[c] as f64;
}
}
}
// ----- convergence test -----
if !changed {
centroids = new_centroids; // update before breaking
break; // assignments stable
}
let diff = &new_centroids - &centroids;
centroids = new_centroids; // Update for the next iteration
if tol > 0.0 {
// Frobenius norm of the centroid movement this iteration
let sq_diff = &diff * &diff;
let shift = sq_diff.data().iter().sum::<f64>().sqrt();
if shift < tol {
break;
}
}
}
(Self { centroids }, labels)
}
/// Predict nearest centroid for each sample.
///
/// Returns one label per row of `x`; an empty `x` yields an empty vector.
pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
let m = x.rows();
let k = self.centroids.rows();
if m == 0 {
return Vec::new();
}
let mut labels = vec![0usize; m];
for i in 0..m {
let sample_row = x.row(i);
let mut best = 0usize;
let mut best_dist_sq = f64::MAX;
for c in 0..k {
let centroid_row = self.centroids.row(c);
// Squared Euclidean distance (no sqrt needed for argmin)
let dist_sq: f64 = sample_row
.iter()
.zip(centroid_row.iter())
.map(|(a, b)| (a - b).powi(2))
.sum();
if dist_sq < best_dist_sq {
best_dist_sq = dist_sq;
best = c;
}
}
labels[i] = best;
}
labels
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_k_means_empty_cluster_reinit_centroid() {
// Try multiple times to increase the chance of hitting the empty cluster case
// (initialization is random, so the empty-cluster branch is not guaranteed)
for _ in 0..20 {
let data = vec![0.0, 0.0, 0.0, 0.0, 10.0, 10.0];
let x = FloatMatrix::from_rows_vec(data, 3, 2);
let k = 2;
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
// Check if any cluster is empty
let mut counts = vec![0; k];
for &label in &labels {
counts[label] += 1;
}
if counts.iter().any(|&c| c == 0) {
// Only check the property for clusters that are empty:
// a re-seeded centroid must coincide with some data point
let centroids = kmeans_model.centroids;
for c in 0..k {
if counts[c] == 0 {
let mut matches_data_point = false;
for i in 0..3 {
let dx = centroids[(c, 0)] - x[(i, 0)];
let dy = centroids[(c, 1)] - x[(i, 1)];
if dx.abs() < 1e-9 && dy.abs() < 1e-9 {
matches_data_point = true;
break;
}
}
assert!(matches_data_point, "Centroid {} (empty cluster) does not match any data point", c);
}
}
break;
}
}
// If we never saw an empty cluster, that's fine; the test passes as long as no panic occurred
}
// NOTE: `use` items may appear anywhere in a module in Rust; these apply
// to the whole `tests` module even though a test precedes them.
use super::*;
use crate::matrix::FloatMatrix;
fn create_test_data() -> (FloatMatrix, usize) {
// Simple 2D data for testing K-Means
// Cluster 1: (1,1), (1.5,1.5)
// Cluster 2: (5,8), (8,8), (6,7)
let data = vec![
1.0, 1.0, // Sample 0
1.5, 1.5, // Sample 1
5.0, 8.0, // Sample 2
8.0, 8.0, // Sample 3
6.0, 7.0, // Sample 4
];
let x = FloatMatrix::from_rows_vec(data, 5, 2);
let k = 2;
(x, k)
}
// Helper for single cluster test with exact mean
fn create_simple_integer_data() -> FloatMatrix {
// Data points: (1,1), (2,2), (3,3)
FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2)
}
#[test]
fn test_k_means_fit_predict_basic() {
let (x, k) = create_test_data();
let max_iter = 100;
let tol = 1e-6;
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
// Assertions for fit
assert_eq!(kmeans_model.centroids.rows(), k);
assert_eq!(kmeans_model.centroids.cols(), x.cols());
assert_eq!(labels.len(), x.rows());
// Check if labels are within expected range (0 to k-1)
for &label in &labels {
assert!(label < k);
}
// Predict with the same data
let predicted_labels = kmeans_model.predict(&x);
// The exact labels might vary due to random initialization,
// but the clustering should be consistent.
// We expect two clusters. Let's check if samples 0,1 are in one cluster
// and samples 2,3,4 are in another.
let cluster_0_members = vec![labels[0], labels[1]];
let cluster_1_members = vec![labels[2], labels[3], labels[4]];
// All members of cluster 0 should have the same label
assert_eq!(cluster_0_members[0], cluster_0_members[1]);
// All members of cluster 1 should have the same label
assert_eq!(cluster_1_members[0], cluster_1_members[1]);
assert_eq!(cluster_1_members[0], cluster_1_members[2]);
// The two clusters should have different labels
assert_ne!(cluster_0_members[0], cluster_1_members[0]);
// Check predicted labels are consistent with fitted labels
assert_eq!(labels, predicted_labels);
// Test with a new sample
let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0
let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2);
let new_sample_label = kmeans_model.predict(&new_sample)[0];
assert_eq!(new_sample_label, cluster_0_members[0]);
let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1
let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2);
let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0];
assert_eq!(new_sample_label_2, cluster_1_members[0]);
}
#[test]
fn test_k_means_fit_k_equals_m() {
// Test case where k (number of clusters) equals m (number of samples)
let (x, _) = create_test_data(); // 5 samples
let k = 5; // 5 clusters
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
assert_eq!(kmeans_model.centroids.rows(), k);
assert_eq!(labels.len(), x.rows());
// Each sample should be its own cluster. Due to random init, labels
// might not be [0,1,2,3,4] but will be a permutation of it.
let mut sorted_labels = labels.clone();
sorted_labels.sort_unstable();
sorted_labels.dedup();
// Labels should all be unique when k==m
assert_eq!(sorted_labels.len(), k);
}
#[test]
#[should_panic(expected = "k must be ≤ number of samples")]
fn test_k_means_fit_k_greater_than_m() {
let (x, _) = create_test_data(); // 5 samples
let k = 6; // k > m
let max_iter = 10;
let tol = 1e-6;
let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
}
#[test]
fn test_k_means_fit_single_cluster() {
// Test with k=1
let x = create_simple_integer_data(); // Use integer data
let k = 1;
let max_iter = 100;
let tol = 1e-6;
let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);
assert_eq!(kmeans_model.centroids.rows(), 1);
assert_eq!(labels.len(), x.rows());
// All labels should be 0
assert!(labels.iter().all(|&l| l == 0));
// Centroid should be the mean of all data points
let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;
assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
}
#[test]
fn test_k_means_predict_empty_matrix() {
let (x, k) = create_test_data();
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
// The `Matrix` type not support 0xN or Nx0 matrices.
// test with a 0x0 matrix is a valid edge case.
let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
let predicted_labels = kmeans_model.predict(&empty_x);
assert!(predicted_labels.is_empty());
}
#[test]
fn test_k_means_predict_single_sample() {
let (x, k) = create_test_data();
let max_iter = 10;
let tol = 1e-6;
let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2);
let predicted_label = kmeans_model.predict(&single_sample);
assert_eq!(predicted_label.len(), 1);
assert!(predicted_label[0] < k);
}
}

View File

@@ -1,54 +0,0 @@
use crate::matrix::{Matrix, SeriesOps};
/// Linear regression model trained by gradient descent (see [`LinReg::fit`]).
pub struct LinReg {
w: Matrix<f64>, // shape (n_features, 1)
b: f64,         // scalar intercept
}
impl LinReg {
    /// Create a model with `n_features` weights, all initialised to zero.
    pub fn new(n_features: usize) -> Self {
        let zero_weights = vec![0.0; n_features];
        Self {
            w: Matrix::from_vec(zero_weights, n_features, 1),
            b: 0.0,
        }
    }

    /// Linear forward pass: ŷ = Xw + b, shape (m, 1).
    pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
        x.dot(&self.w) + self.b
    }

    /// Fit by full-batch gradient descent on the mean-squared-error objective.
    pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
        let n_samples = x.rows() as f64;
        // 2/m factor from d/dw of (1/m)·Σ(ŷ-y)²; invariant across epochs.
        let scale = 2.0 / n_samples;
        for _epoch in 0..epochs {
            // Residuals of the current fit, shape (m, 1).
            let residual = &self.predict(x) - y;
            // Gradients: ∂L/∂w = (2/m)·Xᵀr (n,1) and ∂L/∂b = (2/m)·Σr.
            let grad_w = x.transpose().dot(&residual) * scale;
            let grad_b = scale * residual.sum_vertical().iter().sum::<f64>();
            // Plain gradient-descent update.
            self.w = &self.w - &(grad_w * lr);
            self.b -= lr * grad_b;
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linreg_fit_predict() {
// Perfectly linear data: y = x + 1, so gradient descent should
// recover the line and reproduce the targets to within 1e-2.
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
let y = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0], 4, 1);
let mut model = LinReg::new(1);
model.fit(&x, &y, 0.01, 10000);
let preds = model.predict(&x);
assert!((preds[(0, 0)] - 2.0).abs() < 1e-2);
assert!((preds[(1, 0)] - 3.0).abs() < 1e-2);
assert!((preds[(2, 0)] - 4.0).abs() < 1e-2);
assert!((preds[(3, 0)] - 5.0).abs() < 1e-2);
}
}

View File

@@ -1,55 +0,0 @@
use crate::compute::models::activations::sigmoid;
use crate::matrix::{Matrix, SeriesOps};
/// Binary logistic regression model trained by gradient descent (see [`LogReg::fit`]).
pub struct LogReg {
w: Matrix<f64>, // shape (n_features, 1)
b: f64,         // scalar intercept
}
impl LogReg {
    /// New model with every weight and the bias at zero.
    pub fn new(n_features: usize) -> Self {
        Self {
            w: Matrix::zeros(n_features, 1),
            b: 0.0,
        }
    }

    /// Probability estimates: σ(Xw + b), shape (m, 1).
    pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> {
        let logits = x.dot(&self.w) + self.b;
        sigmoid(&logits)
    }

    /// Fit by full-batch gradient descent on binary cross-entropy.
    pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
        let n_samples = x.rows() as f64;
        for _epoch in 0..epochs {
            // For BCE behind a sigmoid, dL/dz reduces to (p - y).
            let residual = &self.predict_proba(x) - y;
            let grad_w = x.transpose().dot(&residual) / n_samples;
            let grad_b = residual.sum_vertical().iter().sum::<f64>() / n_samples;
            self.w = &self.w - &(grad_w * lr);
            self.b -= lr * grad_b;
        }
    }

    /// Hard class labels: 1.0 where p ≥ 0.5, else 0.0.
    pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
        self.predict_proba(x).map(|p| f64::from(p >= 0.5))
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_logreg_fit_predict() {
// Linearly separable 1-D data with the boundary between 2 and 3;
// after training the model should classify the training points exactly.
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
let mut model = LogReg::new(1);
model.fit(&x, &y, 0.01, 10000);
let preds = model.predict(&x);
assert_eq!(preds[(0, 0)], 0.0);
assert_eq!(preds[(1, 0)], 0.0);
assert_eq!(preds[(2, 0)], 1.0);
assert_eq!(preds[(3, 0)], 1.0);
}
}

View File

@@ -1,7 +0,0 @@
// Model implementations for the `compute::models` namespace.
pub mod activations; // activation functions (e.g. sigmoid) shared by the models
pub mod dense_nn; // configurable dense feed-forward neural network
pub mod gaussian_nb; // Gaussian Naive Bayes classifier
pub mod k_means; // K-Means clustering (Lloyd's algorithm)
pub mod linreg; // linear regression via gradient descent
pub mod logreg; // logistic regression via gradient descent
pub mod pca; // principal component analysis

View File

@@ -1,114 +0,0 @@
use crate::compute::stats::correlation::covariance_matrix;
use crate::compute::stats::descriptive::mean_vertical;
use crate::matrix::{Axis, Matrix, SeriesOps};
/// Returns the `n_components` principal axes (rows) and the centred data's mean.
pub struct PCA {
pub components: Matrix<f64>, // (n_components, n_features)
pub mean: Matrix<f64>, // (1, n_features) per-feature mean learned by `fit`
}
impl PCA {
/// Fit the model on `x` (n_samples, n_features).
///
/// NOTE(review): this does not perform an eigendecomposition — the
/// "components" are literally the first `n_components` rows of the feature
/// covariance matrix, not its eigenvectors, and `_iters` is unused (no
/// power iteration). The unit tests below rely on exactly this behaviour;
/// confirm intent before replacing it with a true PCA.
pub fn fit(x: &Matrix<f64>, n_components: usize, _iters: usize) -> Self {
let mean = mean_vertical(x); // Mean of each feature (column)
let broadcasted_mean = mean.broadcast_row_to_target_shape(x.rows(), x.cols());
let centered_data = x.zip(&broadcasted_mean, |x_i, mean_i| x_i - mean_i);
let covariance_matrix = covariance_matrix(&centered_data, Axis::Col); // Covariance between features
// Copy up to n_components rows of the covariance matrix; any extra
// requested rows are left as zeros.
let mut components = Matrix::zeros(n_components, x.cols());
for i in 0..n_components {
if i < covariance_matrix.rows() {
components.row_copy_from_slice(i, &covariance_matrix.row(i));
} else {
break;
}
}
PCA { components, mean }
}
/// Project new data on the learned axes.
///
/// Centers `x` with the fitted mean, then multiplies by componentsᵀ,
/// giving an (n_samples, n_components) matrix.
pub fn transform(&self, x: &Matrix<f64>) -> Matrix<f64> {
let broadcasted_mean = self.mean.broadcast_row_to_target_shape(x.rows(), x.cols());
let centered_data = x.zip(&broadcasted_mean, |x_i, mean_i| x_i - mean_i);
centered_data.matrix_mul(&self.components.transpose())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::matrix::Matrix;
// Tolerance for floating-point comparisons
const EPSILON: f64 = 1e-8;
#[test]
fn test_pca_basic() {
// Simple 2D data, points along y=x line
// Data:
// 1.0, 1.0
// 2.0, 2.0
// 3.0, 3.0
let data = Matrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2);
let (_n_samples, _n_features) = data.shape();
let pca = PCA::fit(&data, 1, 0); // n_components = 1, iters is unused
println!("Data shape: {:?}", data.shape());
println!("PCA mean shape: {:?}", pca.mean.shape());
println!("PCA components shape: {:?}", pca.components.shape());
// Expected mean: (2.0, 2.0)
assert!((pca.mean.get(0, 0) - 2.0).abs() < EPSILON);
assert!((pca.mean.get(0, 1) - 2.0).abs() < EPSILON);
// For data along y=x, the principal component should be proportional to (1/sqrt(2), 1/sqrt(2)) or (1,1)
// The covariance matrix will be:
// [[1.0, 1.0],
// [1.0, 1.0]]
// The principal component (eigenvector) will be (0.707, 0.707) or (-0.707, -0.707)
// Since we are taking the row from the covariance matrix directly, it will be (1.0, 1.0)
assert!((pca.components.get(0, 0) - 1.0).abs() < EPSILON);
assert!((pca.components.get(0, 1) - 1.0).abs() < EPSILON);
// Test transform
// Centered data:
// -1.0, -1.0
// 0.0, 0.0
// 1.0, 1.0
// Projected: (centered_data * components.transpose())
// (-1.0 * 1.0 + -1.0 * 1.0) = -2.0
// ( 0.0 * 1.0 + 0.0 * 1.0) = 0.0
// ( 1.0 * 1.0 + 1.0 * 1.0) = 2.0
let transformed_data = pca.transform(&data);
assert_eq!(transformed_data.rows(), 3);
assert_eq!(transformed_data.cols(), 1);
assert!((transformed_data.get(0, 0) - -2.0).abs() < EPSILON);
assert!((transformed_data.get(1, 0) - 0.0).abs() < EPSILON);
assert!((transformed_data.get(2, 0) - 2.0).abs() < EPSILON);
}
#[test]
fn test_pca_fit_break_branch() {
// Data with 2 features
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
let (_n_samples, n_features) = data.shape();
// Set n_components greater than n_features to trigger the break branch
let n_components_large = n_features + 1;
let pca = PCA::fit(&data, n_components_large, 0);
// The components matrix should be initialized with n_components_large rows,
// but only the first n_features rows should be copied from the covariance matrix.
// The remaining rows should be zeros.
assert_eq!(pca.components.rows(), n_components_large);
assert_eq!(pca.components.cols(), n_features);
// Verify that rows beyond n_features are all zeros
for i in n_features..n_components_large {
for j in 0..n_features {
assert!((pca.components.get(i, j) - 0.0).abs() < EPSILON);
}
}
}
}

View File

@@ -1,251 +0,0 @@
use crate::compute::stats::{mean, mean_horizontal, mean_vertical, stddev};
use crate::matrix::{Axis, Matrix, SeriesOps};
/// Population covariance between two equally-sized matrices (flattened)
///
/// Treats both matrices as flat sequences of N = rows·cols values and
/// returns Σ (xᵢ - μx)(yᵢ - μy) / N. Panics if the shapes differ.
pub fn covariance(x: &Matrix<f64>, y: &Matrix<f64>) -> f64 {
    assert_eq!(x.rows(), y.rows());
    assert_eq!(x.cols(), y.cols());
    let count = (x.rows() * x.cols()) as f64;
    let (mu_x, mu_y) = (mean(x), mean(y));
    // Accumulate the cross-deviation products over the flattened entries.
    let mut acc = 0.0;
    for (&a, &b) in x.data().iter().zip(y.data().iter()) {
        acc += (a - mu_x) * (b - mu_y);
    }
    acc / count
}
// Population covariance matrix along one axis.
// Axis::Row pairs up columns (result is cols x cols); Axis::Col pairs up
// rows (result is rows x rows). Both divide by the count, not count-1.
fn _covariance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
match axis {
Axis::Row => {
// Covariance between each pair of columns → cols x cols
let num_rows = x.rows() as f64;
let means = mean_vertical(x); // 1 x cols
let p = x.cols();
// Result is stored row-major: entry (i, j) at data[i * p + j]
let mut data = vec![0.0; p * p];
for i in 0..p {
let mu_i = means.get(0, i);
for j in 0..p {
let mu_j = means.get(0, j);
let mut sum = 0.0;
for r in 0..x.rows() {
let d_i = x.get(r, i) - mu_i;
let d_j = x.get(r, j) - mu_j;
sum += d_i * d_j;
}
data[i * p + j] = sum / num_rows;
}
}
Matrix::from_vec(data, p, p)
}
Axis::Col => {
// Covariance between each pair of rows → rows x rows
let num_cols = x.cols() as f64;
let means = mean_horizontal(x); // rows x 1
let n = x.rows();
let mut data = vec![0.0; n * n];
for i in 0..n {
let mu_i = means.get(i, 0);
for j in 0..n {
let mu_j = means.get(j, 0);
let mut sum = 0.0;
for c in 0..x.cols() {
let d_i = x.get(i, c) - mu_i;
let d_j = x.get(j, c) - mu_j;
sum += d_i * d_j;
}
data[i * n + j] = sum / num_cols;
}
}
Matrix::from_vec(data, n, n)
}
}
}
/// Covariance between columns (i.e. across rows)
///
/// Returns a (cols x cols) population covariance matrix.
pub fn covariance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
_covariance_axis(x, Axis::Row)
}
/// Covariance between rows (i.e. across columns)
///
/// Returns a (rows x rows) population covariance matrix.
pub fn covariance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
_covariance_axis(x, Axis::Col)
}
/// Calculates the covariance matrix of the input data.
/// Assumes input `x` is (n_samples, n_features).
///
/// `axis` selects how the data is centred before the product:
/// `Axis::Col` subtracts the per-column means, `Axis::Row` the per-row
/// means. The result is always (n_features, n_features) and uses the
/// sample (n-1) divisor, unlike the population-covariance helpers above.
pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
let (n_samples, n_features) = x.shape();
let centered_data = match axis {
Axis::Col => {
let mean_matrix = mean_vertical(x); // 1 x n_features
x.zip(
&mean_matrix.broadcast_row_to_target_shape(n_samples, n_features),
|val, m| val - m,
)
}
Axis::Row => {
let mean_matrix = mean_horizontal(x); // n_samples x 1
// Manually create a matrix by broadcasting the column vector across columns
let mut broadcasted_mean = Matrix::zeros(n_samples, n_features);
for r in 0..n_samples {
let mean_val = mean_matrix.get(r, 0);
for c in 0..n_features {
*broadcasted_mean.get_mut(r, c) = *mean_val;
}
}
x.zip(&broadcasted_mean, |val, m| val - m)
}
};
// Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1)
// If x is (n_samples, n_features), then centered_data is (n_samples, n_features)
// centered_data.transpose() is (n_features, n_samples)
// Result is (n_features, n_features)
centered_data.transpose().matrix_mul(&centered_data) / (n_samples as f64 - 1.0)
}
/// Pearson correlation coefficient between two equally-sized matrices
/// (flattened): population covariance divided by the product of the
/// population standard deviations. Returns 0.0 when either standard
/// deviation is zero, to avoid division by zero.
pub fn pearson(x: &Matrix<f64>, y: &Matrix<f64>) -> f64 {
    assert_eq!(x.rows(), y.rows());
    assert_eq!(x.cols(), y.cols());
    let sx = stddev(x);
    let sy = stddev(y);
    if sx == 0.0 || sy == 0.0 {
        // Degenerate (constant) input: correlation is undefined; report 0.
        return 0.0;
    }
    covariance(x, y) / (sx * sy)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;
    // Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-8;
    #[test]
    fn test_covariance_scalar_same_matrix() {
        // M =
        // 1,2
        // 3,4
        // mean = 2.5
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let m = Matrix::from_vec(data.clone(), 2, 2);
        // flatten M: [1,2,3,4], mean = 2.5
        // cov(M,M) = variance of flatten = 1.25
        let cov = covariance(&m, &m);
        assert!((cov - 1.25).abs() < EPS);
    }
    #[test]
    fn test_covariance_scalar_diff_matrix() {
        // x =
        // 1,2
        // 3,4
        // y = 2*x
        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let y = Matrix::from_vec(vec![2.0, 4.0, 6.0, 8.0], 2, 2);
        // mean_x = 2.5, mean_y = 5.0
        // cov = sum((xi-2.5)*(yi-5.0))/4 = 2.5
        let cov_xy = covariance(&x, &y);
        assert!((cov_xy - 2.5).abs() < EPS);
    }
    #[test]
    fn test_covariance_vertical() {
        // M =
        // 1,2
        // 3,4
        // cols are [1,3] and [2,4], each var=1, cov=1
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_vertical(&m);
        // Expect 2x2 matrix of all 1.0
        for i in 0..2 {
            for j in 0..2 {
                assert!((cov_mat.get(i, j) - 1.0).abs() < EPS);
            }
        }
    }
    #[test]
    fn test_covariance_horizontal() {
        // M =
        // 1,2
        // 3,4
        // rows are [1,2] and [3,4], each var=0.25, cov=0.25
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_horizontal(&m);
        // Expect 2x2 matrix of all 0.25
        for i in 0..2 {
            for j in 0..2 {
                assert!((cov_mat.get(i, j) - 0.25).abs() < EPS);
            }
        }
    }
    #[test]
    fn test_covariance_matrix_vertical() {
        // Test with a simple 2x2 matrix
        // M =
        // 1, 2
        // 3, 4
        // Expected covariance matrix (vertical, i.e., between columns):
        // Col1: [1, 3], mean = 2
        // Col2: [2, 4], mean = 3
        // Cov(Col1, Col1) = ((1-2)^2 + (3-2)^2) / (2-1) = (1+1)/1 = 2
        // Cov(Col2, Col2) = ((2-3)^2 + (4-3)^2) / (2-1) = (1+1)/1 = 2
        // Cov(Col1, Col2) = ((1-2)*(2-3) + (3-2)*(4-3)) / (2-1) = ((-1)*(-1) + (1)*(1))/1 = (1+1)/1 = 2
        // Cov(Col2, Col1) = 2
        // Expected:
        // 2, 2
        // 2, 2
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_matrix(&m, Axis::Col);
        assert!((cov_mat.get(0, 0) - 2.0).abs() < EPS);
        assert!((cov_mat.get(0, 1) - 2.0).abs() < EPS);
        assert!((cov_mat.get(1, 0) - 2.0).abs() < EPS);
        assert!((cov_mat.get(1, 1) - 2.0).abs() < EPS);
    }
    #[test]
    fn test_covariance_matrix_horizontal() {
        // Test with a simple 2x2 matrix
        // M =
        // 1, 2
        // 3, 4
        // covariance_matrix with Axis::Row centres each row by its own mean
        // and then computes C^T * C / (n_samples - 1), which is still an
        // n_features x n_features matrix:
        // Row1 centred: [-0.5, 0.5]
        // Row2 centred: [-0.5, 0.5]
        // C^T * C / (2 - 1) =
        //  0.5, -0.5
        // -0.5,  0.5
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_matrix(&m, Axis::Row);
        assert!((cov_mat.get(0, 0) - 0.5).abs() < EPS);
        assert!((cov_mat.get(0, 1) - (-0.5)).abs() < EPS);
        assert!((cov_mat.get(1, 0) - (-0.5)).abs() < EPS);
        assert!((cov_mat.get(1, 1) - 0.5).abs() < EPS);
    }
}

View File

@@ -1,390 +0,0 @@
use crate::matrix::{Axis, Matrix, SeriesOps};
/// Arithmetic mean over every element of the matrix (flattened).
pub fn mean(x: &Matrix<f64>) -> f64 {
    let count = (x.rows() * x.cols()) as f64;
    let total: f64 = x.data().iter().sum();
    total / count
}
/// Per-column means as a 1 x cols matrix.
pub fn mean_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    let row_count = x.rows() as f64;
    let column_sums = x.sum_vertical();
    Matrix::from_vec(column_sums, 1, x.cols()) / row_count
}
/// Per-row means as a rows x 1 matrix.
pub fn mean_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    let col_count = x.cols() as f64;
    let row_sums = x.sum_horizontal();
    Matrix::from_vec(row_sums, x.rows(), 1) / col_count
}
/// Shared implementation for the flattened variance functions.
/// Divides the sum of squared deviations by N (population) or N - 1 (sample).
fn population_or_sample_variance(x: &Matrix<f64>, population: bool) -> f64 {
    let count = (x.rows() * x.cols()) as f64;
    let denom = if population { count } else { count - 1.0 };
    let mu = mean(x);
    let sum_sq: f64 = x.data().iter().map(|&v| (v - mu).powi(2)).sum();
    sum_sq / denom
}
/// Population variance (divides by N) over all elements, flattened.
pub fn population_variance(x: &Matrix<f64>) -> f64 {
    population_or_sample_variance(x, true)
}
/// Sample variance (divides by N - 1) over all elements, flattened.
pub fn sample_variance(x: &Matrix<f64>) -> f64 {
    population_or_sample_variance(x, false)
}
/// Shared implementation for the axis-wise variance helpers.
///
/// * `Axis::Row` — one variance per column, returned as 1 x cols.
/// * `Axis::Col` — one variance per row, returned as rows x 1.
///
/// `population` selects the divisor: N (population) or N - 1 (sample).
fn _population_or_sample_variance_axis(
    x: &Matrix<f64>,
    axis: Axis,
    population: bool,
) -> Matrix<f64> {
    match axis {
        Axis::Row => {
            // Calculate variance for each column (vertical variance)
            let num_rows = x.rows() as f64;
            let mean_of_cols = mean_vertical(x); // 1 x cols matrix
            let mut result_data = vec![0.0; x.cols()];
            for c in 0..x.cols() {
                let mean_val = mean_of_cols.get(0, c); // Mean for current column
                let mut sum_sq_diff = 0.0;
                for r in 0..x.rows() {
                    let diff = x.get(r, c) - mean_val;
                    sum_sq_diff += diff * diff;
                }
                result_data[c] = sum_sq_diff / (if population { num_rows } else { num_rows - 1.0 });
            }
            Matrix::from_vec(result_data, 1, x.cols())
        }
        Axis::Col => {
            // Calculate variance for each row (horizontal variance)
            let num_cols = x.cols() as f64;
            let mean_of_rows = mean_horizontal(x); // rows x 1 matrix
            let mut result_data = vec![0.0; x.rows()];
            for r in 0..x.rows() {
                let mean_val = mean_of_rows.get(r, 0); // Mean for current row
                let mut sum_sq_diff = 0.0;
                for c in 0..x.cols() {
                    let diff = x.get(r, c) - mean_val;
                    sum_sq_diff += diff * diff;
                }
                result_data[r] = sum_sq_diff / (if population { num_cols } else { num_cols - 1.0 });
            }
            Matrix::from_vec(result_data, x.rows(), 1)
        }
    }
}
/// Per-column population variances (1 x cols).
pub fn population_variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Row, true)
}
/// Per-row population variances (rows x 1).
pub fn population_variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Col, true)
}
/// Per-column sample variances (1 x cols), divisor N - 1.
pub fn sample_variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Row, false)
}
/// Per-row sample variances (rows x 1), divisor N - 1.
pub fn sample_variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Col, false)
}
/// Population standard deviation over all elements (sqrt of population variance).
pub fn stddev(x: &Matrix<f64>) -> f64 {
    population_variance(x).sqrt()
}
/// Per-column population standard deviations (1 x cols).
pub fn stddev_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    population_variance_vertical(x).map(|v| v.sqrt())
}
/// Per-row population standard deviations (rows x 1).
pub fn stddev_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    population_variance_horizontal(x).map(|v| v.sqrt())
}
/// Median of all elements (flattened): middle value for odd counts, mean of
/// the two middle values for even counts.
///
/// Sorts with `f64::total_cmp` so the comparison is total and cannot panic —
/// the previous `partial_cmp(..).unwrap()` aborted whenever the data
/// contained a NaN. NaNs now sort to the ends per the IEEE-754 total order.
///
/// # Panics
/// Panics if the matrix is empty.
pub fn median(x: &Matrix<f64>) -> f64 {
    let mut data = x.data().to_vec();
    // Unstable sort is fine: equal f64 values are indistinguishable.
    data.sort_unstable_by(f64::total_cmp);
    let mid = data.len() / 2;
    if data.len() % 2 == 0 {
        (data[mid - 1] + data[mid]) / 2.0
    } else {
        data[mid]
    }
}
/// Shared implementation for the axis-wise median helpers.
///
/// `Axis::Col` computes one median per column of `x` (1 x cols);
/// `Axis::Row` transposes first, giving one median per row (rows x 1).
/// NOTE(review): this Axis convention is the opposite of the variance
/// helpers above (there `Axis::Row` means per-column) — confirm intended.
fn _median_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
    let mx = match axis {
        Axis::Col => x.clone(),
        Axis::Row => x.transpose(),
    };
    let mut result = Vec::with_capacity(mx.cols());
    for c in 0..mx.cols() {
        let mut col = mx.column(c).to_vec();
        // Panics on NaN input (partial_cmp returns None).
        col.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mid = col.len() / 2;
        if col.len() % 2 == 0 {
            result.push((col[mid - 1] + col[mid]) / 2.0);
        } else {
            result.push(col[mid]);
        }
    }
    // Shape the output: 1 x cols for per-column, rows x 1 for per-row.
    let (r, c) = match axis {
        Axis::Col => (1, mx.cols()),
        Axis::Row => (mx.cols(), 1),
    };
    Matrix::from_vec(result, r, c)
}
/// Per-column medians (1 x cols).
pub fn median_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _median_axis(x, Axis::Col)
}
/// Per-row medians (rows x 1).
pub fn median_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _median_axis(x, Axis::Row)
}
/// Nearest-rank percentile of all elements (flattened); `p` must lie in
/// [0, 100]. The index is `round(p/100 * (N-1))` — no interpolation.
///
/// # Panics
/// Panics if `p` is outside [0, 100] or the data contains NaN.
pub fn percentile(x: &Matrix<f64>, p: f64) -> f64 {
    if p < 0.0 || p > 100.0 {
        panic!("Percentile must be between 0 and 100");
    }
    let mut values = x.data().to_vec();
    values.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let last = values.len() as f64 - 1.0;
    let idx = ((p / 100.0) * last).round() as usize;
    values[idx]
}
/// Shared implementation for the axis-wise percentile helpers.
/// Same nearest-rank rule as `percentile`, applied per column
/// (`Axis::Col`, output 1 x cols) or per row via transpose
/// (`Axis::Row`, output rows x 1).
fn _percentile_axis(x: &Matrix<f64>, p: f64, axis: Axis) -> Matrix<f64> {
    if p < 0.0 || p > 100.0 {
        panic!("Percentile must be between 0 and 100");
    }
    let mx: Matrix<f64> = match axis {
        Axis::Col => x.clone(),
        Axis::Row => x.transpose(),
    };
    let mut result = Vec::with_capacity(mx.cols());
    for c in 0..mx.cols() {
        let mut col = mx.column(c).to_vec();
        // Panics on NaN input (partial_cmp returns None).
        col.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let index = ((p / 100.0) * (col.len() as f64 - 1.0)).round() as usize;
        result.push(col[index]);
    }
    // Shape the output to match the requested axis.
    let (r, c) = match axis {
        Axis::Col => (1, mx.cols()),
        Axis::Row => (mx.cols(), 1),
    };
    Matrix::from_vec(result, r, c)
}
/// Per-column nearest-rank percentiles (1 x cols); `p` in [0, 100].
pub fn percentile_vertical(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
    _percentile_axis(x, p, Axis::Col)
}
/// Per-row nearest-rank percentiles (rows x 1); `p` in [0, 100].
pub fn percentile_horizontal(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
    _percentile_axis(x, p, Axis::Row)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;
    // Absolute tolerance for floating-point comparisons.
    const EPSILON: f64 = 1e-8;
    #[test]
    fn test_descriptive_stats_regular_values() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x = Matrix::from_vec(data, 1, 5);
        // Mean
        assert!((mean(&x) - 3.0).abs() < EPSILON);
        // Variance
        assert!((population_variance(&x) - 2.0).abs() < EPSILON);
        // Standard Deviation
        assert!((stddev(&x) - 1.4142135623730951).abs() < EPSILON);
        // Median
        assert!((median(&x) - 3.0).abs() < EPSILON);
        // Percentile
        assert!((percentile(&x, 0.0) - 1.0).abs() < EPSILON);
        assert!((percentile(&x, 25.0) - 2.0).abs() < EPSILON);
        assert!((percentile(&x, 50.0) - 3.0).abs() < EPSILON);
        assert!((percentile(&x, 75.0) - 4.0).abs() < EPSILON);
        assert!((percentile(&x, 100.0) - 5.0).abs() < EPSILON);
        // Even element count: median is the mean of the two middle values.
        let data_even = vec![1.0, 2.0, 3.0, 4.0];
        let x_even = Matrix::from_vec(data_even, 1, 4);
        assert!((median(&x_even) - 2.5).abs() < EPSILON);
    }
    #[test]
    fn test_descriptive_stats_outlier() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 100.0];
        let x = Matrix::from_vec(data, 1, 5);
        // Mean should be heavily affected by outlier
        assert!((mean(&x) - 22.0).abs() < EPSILON);
        // Variance should be heavily affected by outlier
        assert!((population_variance(&x) - 1522.0).abs() < EPSILON);
        // Standard Deviation should be heavily affected by outlier
        assert!((stddev(&x) - 39.0128183970461).abs() < EPSILON);
        // Median should be robust to outlier
        assert!((median(&x) - 3.0).abs() < EPSILON);
    }
    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_panic_low() {
        let data = vec![1.0, 2.0, 3.0];
        let x = Matrix::from_vec(data, 1, 3);
        percentile(&x, -1.0);
    }
    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_panic_high() {
        let data = vec![1.0, 2.0, 3.0];
        let x = Matrix::from_vec(data, 1, 3);
        percentile(&x, 101.0);
    }
    #[test]
    fn test_mean_vertical_horizontal() {
        // 2x3 matrix:
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);
        // Vertical means (per column): [(1+4)/2, (2+5)/2, (3+6)/2]
        let mv = mean_vertical(&x);
        assert!((mv.get(0, 0) - 2.5).abs() < EPSILON);
        assert!((mv.get(0, 1) - 3.5).abs() < EPSILON);
        assert!((mv.get(0, 2) - 4.5).abs() < EPSILON);
        // Horizontal means (per row): [(1+2+3)/3, (4+5+6)/3]
        let mh = mean_horizontal(&x);
        assert!((mh.get(0, 0) - 2.0).abs() < EPSILON);
        assert!((mh.get(1, 0) - 5.0).abs() < EPSILON);
    }
    #[test]
    fn test_variance_vertical_horizontal() {
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);
        // cols: {1,4}, {2,5}, {3,6} all give 2.25
        let vv = population_variance_vertical(&x);
        for c in 0..3 {
            assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
        }
        let vh = population_variance_horizontal(&x);
        assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
        assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);
        // sample variance vertical: denominator is n-1 = 1, so variance is 4.5
        let svv = sample_variance_vertical(&x);
        for c in 0..3 {
            assert!((svv.get(0, c) - 4.5).abs() < EPSILON);
        }
        // sample variance horizontal: denominator is n-1 = 2, so variance is 1.0
        let svh = sample_variance_horizontal(&x);
        assert!((svh.get(0, 0) - 1.0).abs() < EPSILON);
        assert!((svh.get(1, 0) - 1.0).abs() < EPSILON);
    }
    #[test]
    fn test_stddev_vertical_horizontal() {
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);
        // Stddev is sqrt of variance
        let sv = stddev_vertical(&x);
        for c in 0..3 {
            assert!((sv.get(0, c) - 1.5).abs() < EPSILON);
        }
        let sh = stddev_horizontal(&x);
        // sqrt(2/3) ≈ 0.816497
        let expected = (2.0 / 3.0 as f64).sqrt();
        assert!((sh.get(0, 0) - expected).abs() < EPSILON);
        assert!((sh.get(1, 0) - expected).abs() < EPSILON);
        // sample stddev vertical: sqrt(4.5) ≈ 2.12132034
        let ssv = sample_variance_vertical(&x).map(|v| v.sqrt());
        for c in 0..3 {
            assert!((ssv.get(0, c) - 2.1213203435596424).abs() < EPSILON);
        }
        // sample stddev horizontal: sqrt(1.0) = 1.0
        let ssh = sample_variance_horizontal(&x).map(|v| v.sqrt());
        assert!((ssh.get(0, 0) - 1.0).abs() < EPSILON);
        assert!((ssh.get(1, 0) - 1.0).abs() < EPSILON);
    }
    #[test]
    fn test_median_vertical_horizontal() {
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);
        let mv = median_vertical(&x).row(0);
        let expected_v = vec![2.5, 3.5, 4.5];
        assert_eq!(mv, expected_v, "{:?} expected: {:?}", expected_v, mv);
        let mh = median_horizontal(&x).column(0).to_vec();
        let expected_h = vec![2.0, 5.0];
        assert_eq!(mh, expected_h, "{:?} expected: {:?}", expected_h, mh);
    }
    #[test]
    fn test_percentile_vertical_horizontal() {
        // vec of f64 values 1..24 as a 4x6 matrix
        let data: Vec<f64> = (1..=24).map(|x| x as f64).collect();
        let x = Matrix::from_vec(data, 4, 6);
        // columns:
        // 1, 5, 9, 13, 17, 21
        // 2, 6, 10, 14, 18, 22
        // 3, 7, 11, 15, 19, 23
        // 4, 8, 12, 16, 20, 24
        let er0 = vec![1., 5., 9., 13., 17., 21.];
        let er50 = vec![3., 7., 11., 15., 19., 23.];
        let er100 = vec![4., 8., 12., 16., 20., 24.];
        assert_eq!(percentile_vertical(&x, 0.0).data(), er0);
        assert_eq!(percentile_vertical(&x, 50.0).data(), er50);
        assert_eq!(percentile_vertical(&x, 100.0).data(), er100);
        let eh0 = vec![1., 2., 3., 4.];
        let eh50 = vec![13., 14., 15., 16.];
        let eh100 = vec![21., 22., 23., 24.];
        assert_eq!(percentile_horizontal(&x, 0.0).data(), eh0);
        assert_eq!(percentile_horizontal(&x, 50.0).data(), eh50);
        assert_eq!(percentile_horizontal(&x, 100.0).data(), eh100);
    }
    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_out_of_bounds() {
        let data = vec![1.0, 2.0, 3.0];
        let x = Matrix::from_vec(data, 1, 3);
        percentile(&x, -10.0); // Should panic
    }
    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_vertical_out_of_bounds() {
        let m = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3);
        let _ = percentile_vertical(&m, -0.1);
    }
}

View File

@@ -1,382 +0,0 @@
use crate::matrix::{Matrix, SeriesOps};
use std::f64::consts::PI;
/// Approximation of the error function (Abramowitz & Stegun 7.1.26).
/// Maximum absolute error of the rational approximation is about 1.5e-7.
fn erf_func(x: f64) -> f64 {
    // erf is odd: evaluate on |x| and restore the sign at the end.
    let negative = x < 0.0;
    let x = x.abs();
    // A&S 7.1.26 coefficients
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;
    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    if negative {
        -y
    } else {
        y
    }
}
/// Approximation of the error function for matrices
pub fn erf(x: Matrix<f64>) -> Matrix<f64> {
x.map(|v| erf_func(v))
}
/// PDF of the Normal distribution with the given mean and standard deviation.
fn normal_pdf_func(x: f64, mean: f64, sd: f64) -> f64 {
    // Standardise, then evaluate the Gaussian density.
    let z = (x - mean) / sd;
    let norm_const = 1.0 / (sd * (2.0 * PI).sqrt());
    norm_const * (-0.5 * z * z).exp()
}
/// Element-wise Normal PDF over a matrix.
pub fn normal_pdf(x: Matrix<f64>, mean: f64, sd: f64) -> Matrix<f64> {
    let pdf = move |v: f64| normal_pdf_func(v, mean, sd);
    x.map(pdf)
}
/// CDF of the Normal distribution, evaluated through the erf approximation:
/// Φ(x) = (1 + erf((x - μ) / (σ√2))) / 2.
fn normal_cdf_func(x: f64, mean: f64, sd: f64) -> f64 {
    let scaled = (x - mean) / (sd * 2.0_f64.sqrt());
    0.5 * (1.0 + erf_func(scaled))
}
/// Element-wise Normal CDF over a matrix.
pub fn normal_cdf(x: Matrix<f64>, mean: f64, sd: f64) -> Matrix<f64> {
    let cdf = move |v: f64| normal_cdf_func(v, mean, sd);
    x.map(cdf)
}
/// PDF of the Uniform distribution on [a, b]: constant 1/(b-a) inside the
/// support, zero outside.
fn uniform_pdf_func(x: f64, a: f64, b: f64) -> f64 {
    match x {
        v if v < a || v > b => 0.0,
        _ => 1.0 / (b - a),
    }
}
/// Element-wise Uniform[a, b] PDF over a matrix.
pub fn uniform_pdf(x: Matrix<f64>, a: f64, b: f64) -> Matrix<f64> {
    let pdf = move |v: f64| uniform_pdf_func(v, a, b);
    x.map(pdf)
}
/// CDF of the Uniform distribution on [a, b]: 0 below a, linear ramp on
/// [a, b], 1 above b.
fn uniform_cdf_func(x: f64, a: f64, b: f64) -> f64 {
    if x < a {
        return 0.0;
    }
    if x <= b {
        return (x - a) / (b - a);
    }
    1.0
}
/// Element-wise Uniform[a, b] CDF over a matrix.
pub fn uniform_cdf(x: Matrix<f64>, a: f64, b: f64) -> Matrix<f64> {
    let cdf = move |v: f64| uniform_cdf_func(v, a, b);
    x.map(cdf)
}
/// Gamma Function (Lanczos approximation)
///
/// Uses the g = 7, n = 9 Lanczos coefficient set; for z < 0.5 the
/// reflection formula Γ(z)Γ(1-z) = π / sin(πz) is applied recursively so
/// the approximation is only ever evaluated on its valid domain.
fn gamma_func(z: f64) -> f64 {
    // Lanczos coefficients
    let p: [f64; 8] = [
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];
    if z < 0.5 {
        // Reflection formula for the left half-plane.
        PI / ((PI * z).sin() * gamma_func(1.0 - z))
    } else {
        let z = z - 1.0;
        // Series term A_g(z); 0.99999... is the leading Lanczos constant.
        let mut x = 0.99999999999980993;
        for (i, &pi) in p.iter().enumerate() {
            x += pi / (z + (i as f64) + 1.0);
        }
        // t = z + g + 0.5 with g = 7.5 - 1 + 1... here p.len() - 0.5 = 7.5.
        let t = z + p.len() as f64 - 0.5;
        (2.0 * PI).sqrt() * t.powf(z + 0.5) * (-t).exp() * x
    }
}
pub fn gamma(z: Matrix<f64>) -> Matrix<f64> {
z.map(|v| gamma_func(v))
}
/// Lower incomplete gamma γ(s, x) via the power-series expansion
/// γ(s, x) = x^s e^{-x} Σ_{n≥0} x^n / (s (s+1) … (s+n)).
///
/// The series converges for all x ≥ 0 (fastest when x < s + 1). Terms are
/// accumulated until they stop contributing at f64 precision, with a cap of
/// 200 iterations — the previous fixed 100-term loop could truncate the
/// series for larger x and always did the full 100 multiplications even
/// when convergence was reached after a handful of terms.
fn lower_incomplete_gamma_func(s: f64, x: f64) -> f64 {
    let mut sum = 1.0 / s;
    let mut term = sum;
    for n in 1..200 {
        term *= x / (s + n as f64);
        sum += term;
        // Stop once the term is below the rounding granularity of the sum.
        if term.abs() < sum.abs() * f64::EPSILON {
            break;
        }
    }
    sum * x.powf(s) * (-x).exp()
}
/// Element-wise lower incomplete gamma: pairs each shape in `s` with the
/// matching argument in `x`.
pub fn lower_incomplete_gamma(s: Matrix<f64>, x: Matrix<f64>) -> Matrix<f64> {
    s.zip(&x, |shape, arg| lower_incomplete_gamma_func(shape, arg))
}
/// PDF of the Gamma distribution (shape k, scale θ); zero for x < 0.
fn gamma_pdf_func(x: f64, k: f64, theta: f64) -> f64 {
    if x < 0.0 {
        0.0
    } else {
        let coef = 1.0 / (gamma_func(k) * theta.powf(k));
        coef * x.powf(k - 1.0) * (-(x / theta)).exp()
    }
}
/// Element-wise Gamma(k, θ) PDF over a matrix.
pub fn gamma_pdf(x: Matrix<f64>, k: f64, theta: f64) -> Matrix<f64> {
    let pdf = move |v: f64| gamma_pdf_func(v, k, theta);
    x.map(pdf)
}
/// CDF of the Gamma distribution (shape k, scale θ) as the regularised
/// lower incomplete gamma γ(k, x/θ) / Γ(k); zero for x < 0.
fn gamma_cdf_func(x: f64, k: f64, theta: f64) -> f64 {
    if x < 0.0 {
        0.0
    } else {
        lower_incomplete_gamma_func(k, x / theta) / gamma_func(k)
    }
}
/// Element-wise Gamma(k, θ) CDF over a matrix.
pub fn gamma_cdf(x: Matrix<f64>, k: f64, theta: f64) -> Matrix<f64> {
    let cdf = move |v: f64| gamma_cdf_func(v, k, theta);
    x.map(cdf)
}
// --- Factorials and combinations ---
/// Compute n! as f64. Exact only while the product fits in f64's integer
/// range; grows to infinity past n ≈ 170.
fn factorial(n: u64) -> f64 {
    let mut acc = 1.0;
    for i in 1..=n {
        acc *= i as f64;
    }
    acc
}
/// Compute "n choose k" without overflow, multiplying incrementally in f64.
///
/// Returns 0.0 when `k > n`: previously `n - k` underflowed on u64
/// (a panic in debug builds, silent wrap in release) for that case.
fn binomial_coeff(n: u64, k: u64) -> f64 {
    if k > n {
        return 0.0;
    }
    // Exploit the symmetry C(n, k) = C(n, n-k) to shorten the loop.
    let k = k.min(n - k);
    let mut numer = 1.0;
    let mut denom = 1.0;
    for i in 0..k {
        numer *= (n - i) as f64;
        denom *= (i + 1) as f64;
    }
    numer / denom
}
/// PMF of the Binomial(n, p) distribution: C(n,k) p^k (1-p)^(n-k);
/// zero when k > n.
fn binomial_pmf_func(n: u64, k: u64, p: f64) -> f64 {
    if k > n {
        return 0.0;
    }
    let successes = p.powf(k as f64);
    let failures = (1.0 - p).powf((n - k) as f64);
    binomial_coeff(n, k) * successes * failures
}
/// Element-wise Binomial(n, p) PMF evaluated at each count in `k`.
pub fn binomial_pmf(n: u64, k: Matrix<u64>, p: f64) -> Matrix<f64> {
    let probs: Vec<f64> = k
        .data()
        .iter()
        .map(|&count| binomial_pmf_func(n, count, p))
        .collect();
    Matrix::from_vec(probs, k.rows(), k.cols())
}
/// CDF of the Binomial(n, p): sum of the PMF for counts 0..=k.
fn binomial_cdf_func(n: u64, k: u64, p: f64) -> f64 {
    let mut total = 0.0;
    for i in 0..=k {
        total += binomial_pmf_func(n, i, p);
    }
    total
}
/// Element-wise Binomial(n, p) CDF evaluated at each count in `k`.
pub fn binomial_cdf(n: u64, k: Matrix<u64>, p: f64) -> Matrix<f64> {
    let probs: Vec<f64> = k
        .data()
        .iter()
        .map(|&count| binomial_cdf_func(n, count, p))
        .collect();
    Matrix::from_vec(probs, k.rows(), k.cols())
}
/// PMF of the Poisson(λ) distribution: λ^k e^{-λ} / k!.
fn poisson_pmf_func(lambda: f64, k: u64) -> f64 {
    let numerator = lambda.powf(k as f64) * (-lambda).exp();
    numerator / factorial(k)
}
/// Element-wise Poisson(λ) PMF evaluated at each count in `k`.
pub fn poisson_pmf(lambda: f64, k: Matrix<u64>) -> Matrix<f64> {
    let probs: Vec<f64> = k
        .data()
        .iter()
        .map(|&count| poisson_pmf_func(lambda, count))
        .collect();
    Matrix::from_vec(probs, k.rows(), k.cols())
}
/// CDF of the Poisson(λ): sum of the PMF for counts 0..=k.
fn poisson_cdf_func(lambda: f64, k: u64) -> f64 {
    let mut total = 0.0;
    for i in 0..=k {
        total += poisson_pmf_func(lambda, i);
    }
    total
}
/// Element-wise Poisson(λ) CDF evaluated at each count in `k`.
pub fn poisson_cdf(lambda: f64, k: Matrix<u64>) -> Matrix<f64> {
    let probs: Vec<f64> = k
        .data()
        .iter()
        .map(|&count| poisson_cdf_func(lambda, count))
        .collect();
    Matrix::from_vec(probs, k.rows(), k.cols())
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_math_funcs() {
        // Test erf function
        assert!((erf_func(0.0) - 0.0).abs() < 1e-7);
        assert!((erf_func(1.0) - 0.8427007).abs() < 1e-7);
        assert!((erf_func(-1.0) + 0.8427007).abs() < 1e-7);
        // Test gamma function against the factorial identity Γ(n) = (n-1)!
        assert!((gamma_func(1.0) - 1.0).abs() < 1e-7);
        assert!((gamma_func(2.0) - 1.0).abs() < 1e-7);
        assert!((gamma_func(3.0) - 2.0).abs() < 1e-7);
        assert!((gamma_func(4.0) - 6.0).abs() < 1e-7);
        assert!((gamma_func(5.0) - 24.0).abs() < 1e-7);
        // Reflection branch (z < 0.5) must agree with the reflection formula.
        let z = 0.3;
        let expected = PI / ((PI * z).sin() * gamma_func(1.0 - z));
        assert!((gamma_func(z) - expected).abs() < 1e-7);
    }
    #[test]
    fn test_math_matrix() {
        let x = Matrix::filled(5, 5, 1.0);
        let erf_result = erf(x.clone());
        assert!((erf_result.data()[0] - 0.8427007).abs() < 1e-7);
        let gamma_result = gamma(x);
        assert!((gamma_result.data()[0] - 1.0).abs() < 1e-7);
    }
    #[test]
    fn test_normal_funcs() {
        assert!((normal_pdf_func(0.0, 0.0, 1.0) - 0.39894228).abs() < 1e-7);
        assert!((normal_cdf_func(1.0, 0.0, 1.0) - 0.8413447).abs() < 1e-7);
    }
    #[test]
    fn test_normal_matrix() {
        let x = Matrix::filled(5, 5, 0.0);
        let pdf = normal_pdf(x.clone(), 0.0, 1.0);
        let cdf = normal_cdf(x, 0.0, 1.0);
        assert!((pdf.data()[0] - 0.39894228).abs() < 1e-7);
        assert!((cdf.data()[0] - 0.5).abs() < 1e-7);
    }
    #[test]
    fn test_uniform_funcs() {
        assert_eq!(uniform_pdf_func(0.5, 0.0, 1.0), 1.0);
        assert_eq!(uniform_cdf_func(-1.0, 0.0, 1.0), 0.0);
        assert_eq!(uniform_cdf_func(0.5, 0.0, 1.0), 0.5);
        // x<a (or x>b) should return 0
        assert_eq!(uniform_pdf_func(-0.5, 0.0, 1.0), 0.0);
        assert_eq!(uniform_pdf_func(1.5, 0.0, 1.0), 0.0);
        // for cdf x>a AND x>b should return 1
        assert_eq!(uniform_cdf_func(1.5, 0.0, 1.0), 1.0);
        assert_eq!(uniform_cdf_func(2.0, 0.0, 1.0), 1.0);
    }
    #[test]
    fn test_uniform_matrix() {
        let x = Matrix::filled(5, 5, 0.5);
        let pdf = uniform_pdf(x.clone(), 0.0, 1.0);
        let cdf = uniform_cdf(x, 0.0, 1.0);
        assert_eq!(pdf.data()[0], 1.0);
        assert_eq!(cdf.data()[0], 0.5);
    }
    #[test]
    fn test_binomial_funcs() {
        let pmf = binomial_pmf_func(5, 2, 0.5);
        assert!((pmf - 0.3125).abs() < 1e-7);
        let cdf = binomial_cdf_func(5, 2, 0.5);
        assert!((cdf - (0.03125 + 0.15625 + 0.3125)).abs() < 1e-7);
        let pmf_zero = binomial_pmf_func(5, 6, 0.5);
        assert!(pmf_zero == 0.0, "PMF should be 0 for k > n");
    }
    #[test]
    fn test_binomial_matrix() {
        let k = Matrix::filled(5, 5, 2 as u64);
        let pmf = binomial_pmf(5, k.clone(), 0.5);
        let cdf = binomial_cdf(5, k, 0.5);
        assert!((pmf.data()[0] - 0.3125).abs() < 1e-7);
        assert!((cdf.data()[0] - (0.03125 + 0.15625 + 0.3125)).abs() < 1e-7);
    }
    #[test]
    fn test_poisson_funcs() {
        // Closed form for comparison: λ^k e^{-λ} / k! with λ=3, k=2.
        let pmf: f64 = poisson_pmf_func(3.0, 2);
        assert!((pmf - (3.0_f64.powf(2.0) * (-3.0 as f64).exp() / 2.0)).abs() < 1e-7);
        let cdf: f64 = poisson_cdf_func(3.0, 2);
        assert!((cdf - (pmf + poisson_pmf_func(3.0, 0) + poisson_pmf_func(3.0, 1))).abs() < 1e-7);
    }
    #[test]
    fn test_poisson_matrix() {
        let k = Matrix::filled(5, 5, 2);
        let pmf = poisson_pmf(3.0, k.clone());
        let cdf = poisson_cdf(3.0, k);
        assert!((pmf.data()[0] - (3.0_f64.powf(2.0) * (-3.0 as f64).exp() / 2.0)).abs() < 1e-7);
        assert!(
            (cdf.data()[0] - (pmf.data()[0] + poisson_pmf_func(3.0, 0) + poisson_pmf_func(3.0, 1)))
                .abs()
                < 1e-7
        );
    }
    #[test]
    fn test_gamma_funcs() {
        // For k=1, θ=1 the Gamma(1,1) is Exp(1), so pdf(x)=e^-x
        assert!((gamma_pdf_func(2.0, 1.0, 1.0) - (-2.0 as f64).exp()).abs() < 1e-7);
        assert!((gamma_cdf_func(2.0, 1.0, 1.0) - (1.0 - (-2.0 as f64).exp())).abs() < 1e-7);
        // <0 case
        assert_eq!(gamma_pdf_func(-1.0, 1.0, 1.0), 0.0);
        assert_eq!(gamma_cdf_func(-1.0, 1.0, 1.0), 0.0);
    }
    #[test]
    fn test_gamma_matrix() {
        let x = Matrix::filled(5, 5, 2.0);
        let pdf = gamma_pdf(x.clone(), 1.0, 1.0);
        let cdf = gamma_cdf(x, 1.0, 1.0);
        assert!((pdf.data()[0] - (-2.0 as f64).exp()).abs() < 1e-7);
        assert!((cdf.data()[0] - (1.0 - (-2.0 as f64).exp())).abs() < 1e-7);
    }
    #[test]
    fn test_lower_incomplete_gamma() {
        let s = Matrix::filled(5, 5, 2.0);
        let x = Matrix::filled(5, 5, 1.0);
        let expected = lower_incomplete_gamma_func(2.0, 1.0);
        let result = lower_incomplete_gamma(s, x);
        assert!((result.data()[0] - expected).abs() < 1e-7);
    }
}

View File

@@ -1,131 +0,0 @@
use crate::matrix::{Matrix, SeriesOps};
use crate::compute::stats::{gamma_cdf, mean, sample_variance};
/// Two-sample t-test returning (t_statistic, p_value)
///
/// Welch's (unequal-variance) form: t = (m1 - m2) / sqrt(v1/n1 + v2/n2),
/// using sample variances; each matrix is treated as a flattened sample.
///
/// NOTE(review): the returned p-value is a hard-coded 0.5 placeholder —
/// the Welch–Satterthwaite degrees of freedom are computed but unused
/// because no Student-t CDF is implemented yet. Do not interpret it.
pub fn t_test(sample1: &Matrix<f64>, sample2: &Matrix<f64>) -> (f64, f64) {
    let mean1 = mean(sample1);
    let mean2 = mean(sample2);
    let var1 = sample_variance(sample1);
    let var2 = sample_variance(sample2);
    // Observation counts for the flattened samples.
    let n1 = (sample1.rows() * sample1.cols()) as f64;
    let n2 = (sample2.rows() * sample2.cols()) as f64;
    let t_statistic = (mean1 - mean2) / ((var1 / n1 + var2 / n2).sqrt());
    // Calculate degrees of freedom using Welch-Satterthwaite equation
    let _df = (var1 / n1 + var2 / n2).powi(2)
        / ((var1 / n1).powi(2) / (n1 - 1.0) + (var2 / n2).powi(2) / (n2 - 1.0));
    // Calculate p-value using t-distribution CDF (two-tailed)
    // TODO: replace this placeholder with a real two-tailed t CDF using _df.
    let p_value = 0.5;
    (t_statistic, p_value)
}
/// Chi-square test of independence.
///
/// Returns `(chi2_statistic, p_value)` for the observed contingency table.
/// Expected counts are `row_total * col_total / grand_total`; the statistic
/// is Σ (O - E)² / E over all cells, with (rows-1)(cols-1) degrees of
/// freedom.
pub fn chi2_test(observed: &Matrix<f64>) -> (f64, f64) {
    let (rows, cols) = observed.shape();
    let row_sums: Vec<f64> = observed.sum_horizontal();
    let col_sums: Vec<f64> = observed.sum_vertical();
    let grand_total: f64 = observed.data().iter().sum();
    let mut chi2_statistic: f64 = 0.0;
    for i in 0..rows {
        for j in 0..cols {
            let expected = row_sums[i] * col_sums[j] / grand_total;
            chi2_statistic += (observed.get(i, j) - expected).powi(2) / expected;
        }
    }
    let degrees_of_freedom = (rows - 1) * (cols - 1);
    // A chi-square distribution with df degrees of freedom is exactly a
    // Gamma(shape = df/2, scale = 2) distribution, so the CDF must be
    // evaluated with scale 2.0 (previously scale 1.0, which skewed every
    // p-value).
    let p_value = 1.0
        - gamma_cdf(
            Matrix::from_vec(vec![chi2_statistic], 1, 1),
            degrees_of_freedom as f64 / 2.0,
            2.0,
        )
        .get(0, 0);
    (chi2_statistic, p_value)
}
/// One-way ANOVA
///
/// Returns `(f_statistic, p_value)`. Each group matrix is treated as a
/// flattened sample.
///
/// NOTE(review): the p-value approximates the F-distribution tail with a
/// Gamma(dfb/2, scale 1) CDF; `dfw` never enters the p-value, so treat the
/// returned probability as a rough heuristic, not an exact F-test p-value.
pub fn anova(groups: Vec<&Matrix<f64>>) -> (f64, f64) {
    let k = groups.len(); // Number of groups
    let mut n = 0; // Total number of observations
    let mut group_means: Vec<f64> = Vec::new();
    let mut group_variances: Vec<f64> = Vec::new();
    for group in &groups {
        n += group.rows() * group.cols();
        group_means.push(mean(group));
        group_variances.push(sample_variance(group));
    }
    // Grand mean of the group means (unweighted; exact only when all groups
    // have the same size).
    let grand_mean: f64 = group_means.iter().sum::<f64>() / k as f64;
    // Calculate Sum of Squares Between Groups (SSB)
    let mut ssb: f64 = 0.0;
    for i in 0..k {
        ssb += (group_means[i] - grand_mean).powi(2) * (groups[i].rows() * groups[i].cols()) as f64;
    }
    // Calculate Sum of Squares Within Groups (SSW)
    // NOTE(review): multiplies the *sample* variance by n_i rather than
    // n_i - 1, slightly inflating SSW — confirm intended.
    let mut ssw: f64 = 0.0;
    for i in 0..k {
        ssw += group_variances[i] * (groups[i].rows() * groups[i].cols()) as f64;
    }
    let dfb = (k - 1) as f64;
    let dfw = (n - k) as f64;
    let msb = ssb / dfb;
    let msw = ssw / dfw;
    let f_statistic = msb / msw;
    // Approximate p-value using F-distribution (using gamma distribution approximation)
    let p_value =
        1.0 - gamma_cdf(Matrix::from_vec(vec![f_statistic], 1, 1), dfb / 2.0, 1.0).get(0, 0);
    (f_statistic, p_value)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;
    // Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-5;
    #[test]
    fn test_t_test() {
        let sample1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
        let sample2 = Matrix::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
        let (t_statistic, p_value) = t_test(&sample1, &sample2);
        assert!((t_statistic + 5.0).abs() < EPS);
        // Weak check only: t_test's p-value is currently a placeholder.
        assert!(p_value > 0.0 && p_value < 1.0);
    }
    #[test]
    fn test_chi2_test() {
        let observed = Matrix::from_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
        let (chi2_statistic, p_value) = chi2_test(&observed);
        // Range checks only; exact values depend on the gamma CDF used.
        assert!(chi2_statistic > 0.0);
        assert!(p_value > 0.0 && p_value < 1.0);
    }
    #[test]
    fn test_anova() {
        let group1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
        let group2 = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0, 6.0], 1, 5);
        let group3 = Matrix::from_vec(vec![3.0, 4.0, 5.0, 6.0, 7.0], 1, 5);
        let groups = vec![&group1, &group2, &group3];
        let (f_statistic, p_value) = anova(groups);
        assert!(f_statistic > 0.0);
        assert!(p_value > 0.0 && p_value < 1.0);
    }
}

View File

@@ -1,9 +0,0 @@
//! Statistics: descriptive measures, correlation/covariance, probability
//! distributions, and inferential tests.
pub mod correlation;
pub mod descriptive;
pub mod distributions;
pub mod inferential;
// Flatten the submodule APIs into `stats::*` for ergonomic imports.
pub use correlation::*;
pub use descriptive::*;
pub use distributions::*;
pub use inferential::*;

View File

@@ -232,18 +232,11 @@ impl<T: Clone + PartialEq> Frame<T> {
} }
(RowIndex::Date(vals), RowIndexLookup::Date(lookup)) (RowIndex::Date(vals), RowIndexLookup::Date(lookup))
} }
Some(RowIndex::Range(ref r)) => { Some(RowIndex::Range(_)) => {
// If the length of the range does not match the number of rows, panic.
if r.end.saturating_sub(r.start) != num_rows {
panic!( panic!(
"Frame::new: Range index length ({}) mismatch matrix rows ({})", "Frame::new: Cannot explicitly provide a Range index. Use None for default range."
r.end.saturating_sub(r.start),
num_rows
); );
} }
// return the range as is.
(RowIndex::Range(r.clone()), RowIndexLookup::None)
}
None => { None => {
// Default to a sequential range index. // Default to a sequential range index.
(RowIndex::Range(0..num_rows), RowIndexLookup::None) (RowIndex::Range(0..num_rows), RowIndexLookup::None)
@@ -471,11 +464,6 @@ impl<T: Clone + PartialEq> Frame<T> {
deleted_data deleted_data
} }
/// Returns a new `Matrix` that is the transpose of the current frame's matrix.
pub fn transpose(&self) -> Matrix<T> {
self.matrix.transpose()
}
/// Sorts columns alphabetically by name, preserving data associations. /// Sorts columns alphabetically by name, preserving data associations.
pub fn sort_columns(&mut self) { pub fn sort_columns(&mut self) {
let n = self.column_names.len(); let n = self.column_names.len();
@@ -512,45 +500,6 @@ impl<T: Clone + PartialEq> Frame<T> {
} }
} }
pub fn frame_map(&self, f: impl Fn(&T) -> T) -> Frame<T> {
Frame::new(
Matrix::from_vec(
self.matrix.data().iter().map(f).collect(),
self.matrix.rows(),
self.matrix.cols(),
),
self.column_names.clone(),
Some(self.index.clone()),
)
}
pub fn frame_zip(&self, other: &Frame<T>, f: impl Fn(&T, &T) -> T) -> Frame<T> {
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!(
"Frame::frame_zip: incompatible dimensions (self: {}x{}, other: {}x{})",
self.rows(),
self.cols(),
other.rows(),
other.cols()
);
}
Frame::new(
Matrix::from_vec(
self.matrix
.data()
.iter()
.zip(other.matrix.data())
.map(|(a, b)| f(a, b))
.collect(),
self.rows(),
self.cols(),
),
self.column_names.clone(),
Some(self.index.clone()),
)
}
// Internal helpers // Internal helpers
/// Rebuilds the column lookup map to match the current `column_names` ordering. /// Rebuilds the column lookup map to match the current `column_names` ordering.
@@ -832,13 +781,14 @@ impl<T: Clone + PartialEq> IndexMut<&str> for Frame<T> {
/// Panics if column labels or row indices differ between operands. /// Panics if column labels or row indices differ between operands.
macro_rules! impl_elementwise_frame_op { macro_rules! impl_elementwise_frame_op {
($OpTrait:ident, $method:ident) => { ($OpTrait:ident, $method:ident) => {
// &Frame<T> $OpTrait &Frame<T>
impl<'a, 'b, T> std::ops::$OpTrait<&'b Frame<T>> for &'a Frame<T> impl<'a, 'b, T> std::ops::$OpTrait<&'b Frame<T>> for &'a Frame<T>
where where
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>, T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
{ {
type Output = Frame<T>; type Output = Frame<T>;
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> { fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
// Verify matching schema
if self.column_names != rhs.column_names { if self.column_names != rhs.column_names {
panic!( panic!(
"Element-wise {}: column names do not match. Left: {:?}, Right: {:?}", "Element-wise {}: column names do not match. Left: {:?}, Right: {:?}",
@@ -855,47 +805,21 @@ macro_rules! impl_elementwise_frame_op {
rhs.index rhs.index
); );
} }
// Apply the matrix operation
let result_matrix = (&self.matrix).$method(&rhs.matrix); let result_matrix = (&self.matrix).$method(&rhs.matrix);
// Determine index for the result
let new_index = match self.index { let new_index = match self.index {
RowIndex::Range(_) => None, RowIndex::Range(_) => None,
_ => Some(self.index.clone()), _ => Some(self.index.clone()),
}; };
Frame::new(result_matrix, self.column_names.clone(), new_index) Frame::new(result_matrix, self.column_names.clone(), new_index)
} }
} }
// Frame<T> $OpTrait &Frame<T>
impl<'b, T> std::ops::$OpTrait<&'b Frame<T>> for Frame<T>
where
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
{
type Output = Frame<T>;
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
(&self).$method(rhs)
}
}
// &Frame<T> $OpTrait Frame<T>
impl<'a, T> std::ops::$OpTrait<Frame<T>> for &'a Frame<T>
where
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
{
type Output = Frame<T>;
fn $method(self, rhs: Frame<T>) -> Frame<T> {
self.$method(&rhs)
}
}
// Frame<T> $OpTrait Frame<T>
impl<T> std::ops::$OpTrait<Frame<T>> for Frame<T>
where
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
{
type Output = Frame<T>;
fn $method(self, rhs: Frame<T>) -> Frame<T> {
(&self).$method(&rhs)
}
}
}; };
} }
impl_elementwise_frame_op!(Add, add); impl_elementwise_frame_op!(Add, add);
impl_elementwise_frame_op!(Sub, sub); impl_elementwise_frame_op!(Sub, sub);
impl_elementwise_frame_op!(Mul, mul); impl_elementwise_frame_op!(Mul, mul);
@@ -906,10 +830,11 @@ impl_elementwise_frame_op!(Div, div);
/// Panics if column labels or row indices differ between operands. /// Panics if column labels or row indices differ between operands.
macro_rules! impl_bitwise_frame_op { macro_rules! impl_bitwise_frame_op {
($OpTrait:ident, $method:ident) => { ($OpTrait:ident, $method:ident) => {
// &Frame<bool> $OpTrait &Frame<bool>
impl<'a, 'b> std::ops::$OpTrait<&'b Frame<bool>> for &'a Frame<bool> { impl<'a, 'b> std::ops::$OpTrait<&'b Frame<bool>> for &'a Frame<bool> {
type Output = Frame<bool>; type Output = Frame<bool>;
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> { fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
// Verify matching schema
if self.column_names != rhs.column_names { if self.column_names != rhs.column_names {
panic!( panic!(
"Bitwise {}: column names do not match. Left: {:?}, Right: {:?}", "Bitwise {}: column names do not match. Left: {:?}, Right: {:?}",
@@ -926,43 +851,25 @@ macro_rules! impl_bitwise_frame_op {
rhs.index rhs.index
); );
} }
// Apply the matrix operation
let result_matrix = (&self.matrix).$method(&rhs.matrix); let result_matrix = (&self.matrix).$method(&rhs.matrix);
// Determine index for the result
let new_index = match self.index { let new_index = match self.index {
RowIndex::Range(_) => None, RowIndex::Range(_) => None,
_ => Some(self.index.clone()), _ => Some(self.index.clone()),
}; };
Frame::new(result_matrix, self.column_names.clone(), new_index) Frame::new(result_matrix, self.column_names.clone(), new_index)
} }
} }
// Frame<bool> $OpTrait &Frame<bool>
impl<'b> std::ops::$OpTrait<&'b Frame<bool>> for Frame<bool> {
type Output = Frame<bool>;
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
(&self).$method(rhs)
}
}
// &Frame<bool> $OpTrait Frame<bool>
impl<'a> std::ops::$OpTrait<Frame<bool>> for &'a Frame<bool> {
type Output = Frame<bool>;
fn $method(self, rhs: Frame<bool>) -> Frame<bool> {
self.$method(&rhs)
}
}
// Frame<bool> $OpTrait Frame<bool>
impl std::ops::$OpTrait<Frame<bool>> for Frame<bool> {
type Output = Frame<bool>;
fn $method(self, rhs: Frame<bool>) -> Frame<bool> {
(&self).$method(&rhs)
}
}
}; };
} }
impl_bitwise_frame_op!(BitAnd, bitand); impl_bitwise_frame_op!(BitAnd, bitand);
impl_bitwise_frame_op!(BitOr, bitor); impl_bitwise_frame_op!(BitOr, bitor);
impl_bitwise_frame_op!(BitXor, bitxor); impl_bitwise_frame_op!(BitXor, bitxor);
/* ---------- Logical NOT ---------- */
/// Implements logical NOT (`!`) for `Frame<bool>`, consuming the frame. /// Implements logical NOT (`!`) for `Frame<bool>`, consuming the frame.
impl Not for Frame<bool> { impl Not for Frame<bool> {
type Output = Frame<bool>; type Output = Frame<bool>;
@@ -981,30 +888,12 @@ impl Not for Frame<bool> {
} }
} }
/// Implements logical NOT (`!`) for `&Frame<bool>`, borrowing the frame.
impl Not for &Frame<bool> {
type Output = Frame<bool>;
fn not(self) -> Frame<bool> {
// Apply NOT to the underlying matrix
let result_matrix = !&self.matrix;
// Determine index for the result
let new_index = match self.index {
RowIndex::Range(_) => None,
_ => Some(self.index.clone()),
};
Frame::new(result_matrix, self.column_names.clone(), new_index)
}
}
// --- Tests --- // --- Tests ---
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// Assume Matrix is available from crate::matrix or similar // Assume Matrix is available from crate::matrix or similar
use crate::matrix::{BoolOps, Matrix}; use crate::matrix::Matrix;
use chrono::NaiveDate; use chrono::NaiveDate;
// HashMap needed for direct inspection in tests if required // HashMap needed for direct inspection in tests if required
use std::collections::HashMap; use std::collections::HashMap;
@@ -1168,10 +1057,10 @@ mod tests {
Frame::new(matrix, vec!["X", "Y"], Some(index)); Frame::new(matrix, vec!["X", "Y"], Some(index));
} }
#[test] #[test]
#[should_panic(expected = "Frame::new: Range index length (4) mismatch matrix rows (3)")] #[should_panic(expected = "Cannot explicitly provide a Range index")]
fn frame_new_panic_invalid_explicit_range_index() { fn frame_new_panic_explicit_range() {
let matrix = create_test_matrix_f64(); // 3 rows let matrix = create_test_matrix_f64();
let index = RowIndex::Range(0..4); // Range 0..4 but only 3 rows let index = RowIndex::Range(0..3); // User cannot provide Range directly
Frame::new(matrix, vec!["A", "B"], Some(index)); Frame::new(matrix, vec!["A", "B"], Some(index));
} }
@@ -1460,7 +1349,7 @@ mod tests {
fn test_row_view_name_panic() { fn test_row_view_name_panic() {
let frame = create_test_frame_f64(); let frame = create_test_frame_f64();
let row_view = frame.get_row(0); let row_view = frame.get_row(0);
let _ = row_view["C"]; // Access non-existent column Z let _ = row_view["C"]; // Access non-existent column name
} }
#[test] #[test]
#[should_panic(expected = "column index 3 out of bounds")] // Check specific message #[should_panic(expected = "column index 3 out of bounds")] // Check specific message
@@ -1692,45 +1581,6 @@ mod tests {
assert_eq!(frame1.columns(), &["Z"]); assert_eq!(frame1.columns(), &["Z"]);
} }
#[test]
fn test_frame_map() {
let frame = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6]
let mapped_frame = frame.frame_map(|x| x * 2.0); // Multiply each value by 2.0
assert_eq!(mapped_frame.columns(), frame.columns());
assert_eq!(mapped_frame.index(), frame.index());
assert!((mapped_frame["A"][0] - 2.0).abs() < FLOAT_TOLERANCE);
assert!((mapped_frame["A"][1] - 4.0).abs() < FLOAT_TOLERANCE);
assert!((mapped_frame["A"][2] - 6.0).abs() < FLOAT_TOLERANCE);
assert!((mapped_frame["B"][0] - 8.0).abs() < FLOAT_TOLERANCE);
assert!((mapped_frame["B"][1] - 10.0).abs() < FLOAT_TOLERANCE);
assert!((mapped_frame["B"][2] - 12.0).abs() < FLOAT_TOLERANCE);
}
#[test]
fn test_frame_zip() {
let f1 = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6]
let f2 = create_test_frame_f64_alt(); // A=[0.1,0.2,0.3], B=[0.4,0.5,0.6]
let zipped_frame = f1.frame_zip(&f2, |x, y| x + y); // Element-wise addition
assert_eq!(zipped_frame.columns(), f1.columns());
assert_eq!(zipped_frame.index(), f1.index());
assert!((zipped_frame["A"][0] - 1.1).abs() < FLOAT_TOLERANCE);
assert!((zipped_frame["A"][1] - 2.2).abs() < FLOAT_TOLERANCE);
assert!((zipped_frame["A"][2] - 3.3).abs() < FLOAT_TOLERANCE);
assert!((zipped_frame["B"][0] - 4.4).abs() < FLOAT_TOLERANCE);
assert!((zipped_frame["B"][1] - 5.5).abs() < FLOAT_TOLERANCE);
assert!((zipped_frame["B"][2] - 6.6).abs() < FLOAT_TOLERANCE);
}
#[test]
#[should_panic(expected = "Frame::frame_zip: incompatible dimensions (self: 3x1, other: 3x2)")]
fn test_frame_zip_panic() {
let mut f1 = create_test_frame_f64();
let f2 = create_test_frame_f64_alt();
f1.delete_column("B");
f1.frame_zip(&f2, |x, y| x + y); // Should panic due to different column counts
}
// --- Element-wise Arithmetic Ops Tests --- // --- Element-wise Arithmetic Ops Tests ---
#[test] #[test]
fn test_frame_arithmetic_ops_f64() { fn test_frame_arithmetic_ops_f64() {
@@ -1816,79 +1666,6 @@ mod tests {
assert_eq!(frame_div["Y"], vec![10, -10]); assert_eq!(frame_div["Y"], vec![10, -10]);
} }
#[test]
fn tests_for_frame_arithmetic_ops() {
let ops: Vec<(
&str,
fn(&Frame<f64>, &Frame<f64>) -> Frame<f64>,
fn(&Frame<f64>, &Frame<f64>) -> Frame<f64>,
)> = vec![
("addition", |a, b| a + b, |a, b| (&*a) + (&*b)),
("subtraction", |a, b| a - b, |a, b| (&*a) - (&*b)),
("multiplication", |a, b| a * b, |a, b| (&*a) * (&*b)),
("division", |a, b| a / b, |a, b| (&*a) / (&*b)),
];
for (op_name, owned_op, ref_op) in ops {
let f1 = create_test_frame_f64();
let f2 = create_test_frame_f64_alt();
let result_owned = owned_op(&f1, &f2);
let expected = ref_op(&f1, &f2);
assert_eq!(
result_owned.columns(),
f1.columns(),
"Column mismatch for {}",
op_name
);
assert_eq!(
result_owned.index(),
f1.index(),
"Index mismatch for {}",
op_name
);
let bool_mat = result_owned.matrix().eq_elem(expected.matrix().clone());
assert!(bool_mat.all(), "Element-wise {} failed", op_name);
}
}
// test not , and or on frame
#[test]
fn tests_for_frame_bool_ops() {
let ops: Vec<(
&str,
fn(&Frame<bool>, &Frame<bool>) -> Frame<bool>,
fn(&Frame<bool>, &Frame<bool>) -> Frame<bool>,
)> = vec![
("and", |a, b| a & b, |a, b| (&*a) & (&*b)),
("or", |a, b| a | b, |a, b| (&*a) | (&*b)),
("xor", |a, b| a ^ b, |a, b| (&*a) ^ (&*b)),
];
for (op_name, owned_op, ref_op) in ops {
let f1 = create_test_frame_bool();
let f2 = create_test_frame_bool_alt();
let result_owned = owned_op(&f1, &f2);
let expected = ref_op(&f1, &f2);
assert_eq!(
result_owned.columns(),
f1.columns(),
"Column mismatch for {}",
op_name
);
assert_eq!(
result_owned.index(),
f1.index(),
"Index mismatch for {}",
op_name
);
let bool_mat = result_owned.matrix().eq_elem(expected.matrix().clone());
assert!(bool_mat.all(), "Element-wise {} failed", op_name);
}
}
#[test] #[test]
fn test_frame_arithmetic_ops_date_index() { fn test_frame_arithmetic_ops_date_index() {
let dates = vec![d(2024, 1, 1), d(2024, 1, 2)]; let dates = vec![d(2024, 1, 1), d(2024, 1, 2)];

View File

@@ -21,28 +21,6 @@ impl SeriesOps for Frame<f64> {
self.matrix().apply_axis(axis, f) self.matrix().apply_axis(axis, f)
} }
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64,
{
self.matrix().map(f)
}
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64,
{
self.matrix().zip(other.matrix(), f)
}
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
self.matrix().matrix_mul(other.matrix())
}
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix().dot(other.matrix())
}
delegate_to_matrix!( delegate_to_matrix!(
sum_vertical -> Vec<f64>, sum_vertical -> Vec<f64>,
sum_horizontal -> Vec<f64>, sum_horizontal -> Vec<f64>,
@@ -128,7 +106,7 @@ mod tests {
let col_names = vec!["A".to_string(), "B".to_string()]; let col_names = vec!["A".to_string(), "B".to_string()];
let frame = Frame::new( let frame = Frame::new(
Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]), Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]),
col_names.clone(), col_names,
None, None,
); );
assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical()); assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical());
@@ -150,33 +128,6 @@ mod tests {
); );
assert_eq!(frame.is_nan(), frame.matrix().is_nan()); assert_eq!(frame.is_nan(), frame.matrix().is_nan());
assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]); assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]);
assert_eq!(
frame.matrix_mul(&frame),
frame.matrix().matrix_mul(&frame.matrix())
);
assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix()));
// test transpose - returns a matrix.
let frame_transposed_mat = frame.transpose();
let frame_mat_transposed = frame.matrix().transpose();
assert_eq!(frame_transposed_mat, frame_mat_transposed);
assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose());
// test map
let mapped_frame = frame.map(|x| x * 2.0);
let expected_matrix = frame.matrix().map(|x| x * 2.0);
assert_eq!(mapped_frame, expected_matrix);
// test zip
let other_frame = Frame::new(
Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]),
col_names.clone(),
None,
);
let zipped_frame = frame.zip(&other_frame, |x, y| x + y);
let expected_zipped_matrix = frame.matrix().zip(other_frame.matrix(), |x, y| x + y);
assert_eq!(zipped_frame, expected_zipped_matrix);
} }
#[test] #[test]

View File

@@ -8,6 +8,3 @@ pub mod frame;
/// Documentation for the [`crate::utils`] module. /// Documentation for the [`crate::utils`] module.
pub mod utils; pub mod utils;
/// Documentation for the [`crate::compute`] module.
pub mod compute;

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
pub mod boolops;
pub mod mat; pub mod mat;
pub mod seriesops; pub mod seriesops;
pub mod boolops;
pub use boolops::*;
pub use mat::*; pub use mat::*;
pub use seriesops::*; pub use seriesops::*;
pub use boolops::*;

View File

@@ -12,17 +12,6 @@ pub trait SeriesOps {
where where
F: FnMut(&[f64]) -> U; F: FnMut(&[f64]) -> U;
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64;
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64;
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
fn dot(&self, other: &Self) -> FloatMatrix;
fn sum_vertical(&self) -> Vec<f64>; fn sum_vertical(&self) -> Vec<f64>;
fn sum_horizontal(&self) -> Vec<f64>; fn sum_horizontal(&self) -> Vec<f64>;
@@ -150,67 +139,11 @@ impl SeriesOps for FloatMatrix {
let data = self.data().iter().map(|v| v.is_nan()).collect(); let data = self.data().iter().map(|v| v.is_nan()).collect();
BoolMatrix::from_vec(data, self.rows(), self.cols()) BoolMatrix::from_vec(data, self.rows(), self.cols())
} }
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
let (m, n) = (self.rows(), self.cols());
let (n2, p) = (other.rows(), other.cols());
assert_eq!(
n, n2,
"Cannot multiply: left is {}x{}, right is {}x{}",
m, n, n2, p
);
// Column-major addressing: element (row i, col j) lives at j * m + i
let mut data = vec![0.0; m * p];
for i in 0..m {
for j in 0..p {
let mut sum = 0.0;
for k in 0..n {
sum += self[(i, k)] * other[(k, j)];
}
data[j * m + i] = sum; // <-- fixed index
}
}
FloatMatrix::from_vec(data, m, p)
}
fn dot(&self, other: &Self) -> FloatMatrix {
self.matrix_mul(other)
}
fn map<F>(&self, f: F) -> FloatMatrix
where
F: Fn(f64) -> f64,
{
let data = self.data().iter().map(|&v| f(v)).collect::<Vec<_>>();
FloatMatrix::from_vec(data, self.rows(), self.cols())
}
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
where
F: Fn(f64, f64) -> f64,
{
assert!(
self.rows() == other.rows() && self.cols() == other.cols(),
"Matrix dimensions mismatch: left is {}x{}, right is {}x{}",
self.rows(),
self.cols(),
other.rows(),
other.cols()
);
let data = self
.data()
.iter()
.zip(other.data().iter())
.map(|(&a, &b)| f(a, b))
.collect();
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// Helper function to create a FloatMatrix for SeriesOps testing // Helper function to create a FloatMatrix for SeriesOps testing
@@ -223,22 +156,6 @@ mod tests {
FloatMatrix::from_vec(data, 3, 3) FloatMatrix::from_vec(data, 3, 3)
} }
fn create_float_test_matrix_4x4() -> FloatMatrix {
// 4x4 matrix (column-major) with some NaNs
// 1.0 5.0 9.0 13.0
// 2.0 NaN 10.0 NaN
// 3.0 6.0 NaN 14.0
// NaN 7.0 11.0 NaN
// first make array with 16 elements
FloatMatrix::from_vec(
(0..16)
.map(|i| if i % 5 == 0 { f64::NAN } else { i as f64 })
.collect(),
4,
4,
)
}
// --- Tests for SeriesOps (FloatMatrix) --- // --- Tests for SeriesOps (FloatMatrix) ---
#[test] #[test]
@@ -339,90 +256,6 @@ mod tests {
assert_eq!(matrix.is_nan(), expected_matrix); assert_eq!(matrix.is_nan(), expected_matrix);
} }
#[test]
fn test_series_ops_matrix_mul() {
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
// result should be: 23, 34, 31, 46
let expected = FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2);
assert_eq!(a.matrix_mul(&b), expected);
assert_eq!(a.dot(&b), a.matrix_mul(&b)); // dot should be the same as matrix_mul for FloatMatrix
}
#[test]
fn test_series_ops_matrix_mul_with_nans() {
let a = create_float_test_matrix(); // 3x3 matrix with some NaNs
let b = create_float_test_matrix(); // 3x3 matrix with some NaNs
let mut result_vec = Vec::new();
result_vec.push(30.0);
for _ in 1..9 {
result_vec.push(f64::NAN);
}
let expected = FloatMatrix::from_vec(result_vec, 3, 3);
let result = a.matrix_mul(&b);
assert_eq!(result.is_nan(), expected.is_nan());
assert_eq!(
result.count_nan_horizontal(),
expected.count_nan_horizontal()
);
assert_eq!(result.count_nan_vertical(), expected.count_nan_vertical());
assert_eq!(result[(0, 0)], expected[(0, 0)]);
}
#[test]
#[should_panic(expected = "Cannot multiply: left is 3x3, right is 4x4")]
fn test_series_ops_matrix_mul_errors() {
let a = create_float_test_matrix();
let b = create_float_test_matrix_4x4();
a.dot(&b); // This should panic due to dimension mismatch
}
#[test]
fn test_series_ops_map() {
let matrix = create_float_test_matrix();
// Map function to double each value
let mapped_matrix = matrix.map(|x| x * 2.0);
// Expected data after mapping
let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
// assert_eq!(mapped_matrix, expected_matrix);
for i in 0..mapped_matrix.data().len() {
// if not nan, check equality
if !mapped_matrix.data()[i].is_nan() {
assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]);
} else {
assert!(mapped_matrix.data()[i].is_nan());
assert!(expected_matrix.data()[i].is_nan());
}
}
assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
}
#[test]
fn test_series_ops_zip() {
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
// Zip function to add corresponding elements
let zipped_matrix = a.zip(&b, |x, y| x + y);
// Expected data after zipping
let expected_data = vec![6.0, 8.0, 10.0, 12.0];
let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2);
assert_eq!(zipped_matrix, expected_matrix);
}
#[test]
#[should_panic(expected = "Matrix dimensions mismatch: left is 2x2, right is 3x2")]
fn test_series_ops_zip_panic() {
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 2); // 3x2 matrix
// This should panic due to dimension mismatch
a.zip(&b, |x, y| x + y);
}
// --- Edge Cases for SeriesOps --- // --- Edge Cases for SeriesOps ---
#[test] #[test]

2410
src/utils/bdates.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,8 @@ use std::hash::Hash;
use std::result::Result; use std::result::Result;
use std::str::FromStr; use std::str::FromStr;
// --- Core Enums ---
/// Represents the frequency at which calendar dates should be generated. /// Represents the frequency at which calendar dates should be generated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DateFreq { pub enum DateFreq {
@@ -122,6 +124,8 @@ impl FromStr for DateFreq {
} }
} }
// --- DatesList Struct ---
/// Represents a list of calendar dates generated between a start and end date /// Represents a list of calendar dates generated between a start and end date
/// at a specified frequency. Provides methods to retrieve the full list, /// at a specified frequency. Provides methods to retrieve the full list,
/// count, or dates grouped by period. /// count, or dates grouped by period.
@@ -160,7 +164,7 @@ enum GroupKey {
/// ```rust /// ```rust
/// use chrono::NaiveDate; /// use chrono::NaiveDate;
/// use std::error::Error; /// use std::error::Error;
/// use rustframe::utils::{DatesList, DateFreq}; /// # use rustframe::utils::{DatesList, DateFreq}; // Assuming the crate/module is named 'dates'
/// ///
/// # fn main() -> Result<(), Box<dyn Error>> { /// # fn main() -> Result<(), Box<dyn Error>> {
/// let start_date = "2023-11-01".to_string(); // Wednesday /// let start_date = "2023-11-01".to_string(); // Wednesday
@@ -336,7 +340,32 @@ impl DatesList {
/// Returns an error if the start or end date strings cannot be parsed. /// Returns an error if the start or end date strings cannot be parsed.
pub fn groups(&self) -> Result<Vec<Vec<NaiveDate>>, Box<dyn Error>> { pub fn groups(&self) -> Result<Vec<Vec<NaiveDate>>, Box<dyn Error>> {
let dates = self.list()?; let dates = self.list()?;
group_dates_helper(dates, self.freq) let mut groups: HashMap<GroupKey, Vec<NaiveDate>> = HashMap::new();
for date in dates {
let key = match self.freq {
DateFreq::Daily => GroupKey::Daily(date),
DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => {
let iso_week = date.iso_week();
GroupKey::Weekly(iso_week.year(), iso_week.week())
}
DateFreq::MonthStart | DateFreq::MonthEnd => {
GroupKey::Monthly(date.year(), date.month())
}
DateFreq::QuarterStart | DateFreq::QuarterEnd => {
GroupKey::Quarterly(date.year(), month_to_quarter(date.month()))
}
DateFreq::YearStart | DateFreq::YearEnd => GroupKey::Yearly(date.year()),
};
groups.entry(key).or_insert_with(Vec::new).push(date);
}
let mut sorted_groups: Vec<(GroupKey, Vec<NaiveDate>)> = groups.into_iter().collect();
sorted_groups.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
// Dates within groups are already sorted because they came from the sorted `self.list()`.
let result_groups = sorted_groups.into_iter().map(|(_, dates)| dates).collect();
Ok(result_groups)
} }
/// Returns the start date parsed as a `NaiveDate`. /// Returns the start date parsed as a `NaiveDate`.
@@ -378,6 +407,8 @@ impl DatesList {
} }
} }
// --- Dates Generator (Iterator) ---
/// An iterator that generates a sequence of calendar dates based on a start date, /// An iterator that generates a sequence of calendar dates based on a start date,
/// frequency, and a specified number of periods. /// frequency, and a specified number of periods.
/// ///
@@ -461,10 +492,10 @@ impl DatesList {
/// ``` /// ```
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct DatesGenerator { pub struct DatesGenerator {
pub freq: DateFreq, freq: DateFreq,
pub periods_remaining: usize, periods_remaining: usize,
// Stores the *next* date to be yielded by the iterator. // Stores the *next* date to be yielded by the iterator.
pub next_date_candidate: Option<NaiveDate>, next_date_candidate: Option<NaiveDate>,
} }
impl DatesGenerator { impl DatesGenerator {
@@ -530,43 +561,11 @@ impl Iterator for DatesGenerator {
} }
} }
// Internal helper functions // --- Internal helper functions ---
pub fn group_dates_helper(
dates: Vec<NaiveDate>,
freq: DateFreq,
) -> Result<Vec<Vec<NaiveDate>>, Box<dyn Error + 'static>> {
let mut groups: HashMap<GroupKey, Vec<NaiveDate>> = HashMap::new();
for date in dates {
let key = match freq {
DateFreq::Daily => GroupKey::Daily(date),
DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => {
let iso_week = date.iso_week();
GroupKey::Weekly(iso_week.year(), iso_week.week())
}
DateFreq::MonthStart | DateFreq::MonthEnd => {
GroupKey::Monthly(date.year(), date.month())
}
DateFreq::QuarterStart | DateFreq::QuarterEnd => {
GroupKey::Quarterly(date.year(), month_to_quarter(date.month()))
}
DateFreq::YearStart | DateFreq::YearEnd => GroupKey::Yearly(date.year()),
};
groups.entry(key).or_insert_with(Vec::new).push(date);
}
let mut sorted_groups: Vec<(GroupKey, Vec<NaiveDate>)> = groups.into_iter().collect();
sorted_groups.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
// Dates within groups are already sorted because they came from the sorted `self.list()`.
let result_groups = sorted_groups.into_iter().map(|(_, dates)| dates).collect();
Ok(result_groups)
}
/// Generates the flat list of dates for the given range and frequency. /// Generates the flat list of dates for the given range and frequency.
/// Assumes the `collect_*` functions return sorted dates. /// Assumes the `collect_*` functions return sorted dates.
pub fn get_dates_list_with_freq( fn get_dates_list_with_freq(
start_date_str: &str, start_date_str: &str,
end_date_str: &str, end_date_str: &str,
freq: DateFreq, freq: DateFreq,
@@ -602,7 +601,7 @@ pub fn get_dates_list_with_freq(
Ok(dates) Ok(dates)
} }
// Low-Level Date Collection Functions (Internal) /* ---------------------- Low-Level Date Collection Functions (Internal) ---------------------- */
// These functions generate dates within a *range* [start_date, end_date] // These functions generate dates within a *range* [start_date, end_date]
/// Returns all calendar days day-by-day within the range. /// Returns all calendar days day-by-day within the range.
@@ -649,13 +648,8 @@ fn collect_monthly(
let mut year = start_date.year(); let mut year = start_date.year();
let mut month = start_date.month(); let mut month = start_date.month();
let next_month = |(yr, mo): (i32, u32)| -> (i32, u32) { let next_month =
if mo == 12 { |(yr, mo): (i32, u32)| -> (i32, u32) { if mo == 12 { (yr + 1, 1) } else { (yr, mo + 1) } };
(yr + 1, 1)
} else {
(yr, mo + 1)
}
};
loop { loop {
let candidate = if want_first_day { let candidate = if want_first_day {
@@ -734,21 +728,6 @@ fn collect_quarterly(
Ok(result) Ok(result)
} }
/// Returns a list of dates between the given start and end dates, inclusive,
/// at the specified frequency.
/// This function is a convenience wrapper around `get_dates_list_with_freq`.
pub fn get_dates_list_with_freq_from_naive_date(
start_date: NaiveDate,
end_date: NaiveDate,
freq: DateFreq,
) -> Result<Vec<NaiveDate>, Box<dyn Error>> {
get_dates_list_with_freq(
&start_date.format("%Y-%m-%d").to_string(),
&end_date.format("%Y-%m-%d").to_string(),
freq,
)
}
/// Return either the first or last calendar day in each year of the range. /// Return either the first or last calendar day in each year of the range.
fn collect_yearly( fn collect_yearly(
start_date: NaiveDate, start_date: NaiveDate,
@@ -778,6 +757,8 @@ fn collect_yearly(
Ok(result) Ok(result)
} }
/* ---------------------- Core Date Utility Functions (Internal) ---------------------- */
/// Given a date and a `target_weekday`, returns the date that is the first /// Given a date and a `target_weekday`, returns the date that is the first
/// `target_weekday` on or after the given date. /// `target_weekday` on or after the given date.
fn move_to_day_of_week_on_or_after( fn move_to_day_of_week_on_or_after(
@@ -833,7 +814,7 @@ fn last_day_of_month(year: i32, month: u32) -> Result<NaiveDate, Box<dyn Error>>
/// Converts a month number (1-12) to a quarter number (1-4). /// Converts a month number (1-12) to a quarter number (1-4).
/// Panics if month is invalid (should not happen with valid NaiveDate). /// Panics if month is invalid (should not happen with valid NaiveDate).
pub fn month_to_quarter(m: u32) -> u32 { fn month_to_quarter(m: u32) -> u32 {
match m { match m {
1..=3 => 1, 1..=3 => 1,
4..=6 => 2, 4..=6 => 2,
@@ -892,28 +873,9 @@ fn last_day_of_year(year: i32) -> Result<NaiveDate, Box<dyn Error>> {
// --- Generator Helper Functions --- // --- Generator Helper Functions ---
fn get_first_date_helper(freq: DateFreq) -> fn(i32, u32) -> Result<NaiveDate, Box<dyn Error>> {
if matches!(
freq,
DateFreq::Daily | DateFreq::WeeklyMonday | DateFreq::WeeklyFriday
) {
panic!("Daily, WeeklyMonday, and WeeklyFriday frequencies are not supported here");
}
match freq {
DateFreq::MonthStart => first_day_of_month,
DateFreq::MonthEnd => last_day_of_month,
DateFreq::QuarterStart => first_day_of_quarter,
DateFreq::QuarterEnd => last_day_of_quarter,
DateFreq::YearStart => |year, _| first_day_of_year(year),
DateFreq::YearEnd => |year, _| last_day_of_year(year),
_ => unreachable!(),
}
}
/// Finds the *first* valid date according to the frequency, /// Finds the *first* valid date according to the frequency,
/// starting the search *on or after* the given `start_date`. /// starting the search *on or after* the given `start_date`.
pub fn find_first_date_on_or_after( fn find_first_date_on_or_after(
start_date: NaiveDate, start_date: NaiveDate,
freq: DateFreq, freq: DateFreq,
) -> Result<NaiveDate, Box<dyn Error>> { ) -> Result<NaiveDate, Box<dyn Error>> {
@@ -921,42 +883,69 @@ pub fn find_first_date_on_or_after(
DateFreq::Daily => Ok(start_date), // The first daily date is the start date itself DateFreq::Daily => Ok(start_date), // The first daily date is the start date itself
DateFreq::WeeklyMonday => move_to_day_of_week_on_or_after(start_date, Weekday::Mon), DateFreq::WeeklyMonday => move_to_day_of_week_on_or_after(start_date, Weekday::Mon),
DateFreq::WeeklyFriday => move_to_day_of_week_on_or_after(start_date, Weekday::Fri), DateFreq::WeeklyFriday => move_to_day_of_week_on_or_after(start_date, Weekday::Fri),
DateFreq::MonthStart => {
DateFreq::MonthStart | DateFreq::MonthEnd => { let mut candidate = first_day_of_month(start_date.year(), start_date.month())?;
// let mut candidate = first_day_of_month(start_date.year(), start_date.month())?;
let get_cand_func = get_first_date_helper(freq);
let mut candidate = get_cand_func(start_date.year(), start_date.month())?;
if candidate < start_date { if candidate < start_date {
let (next_y, next_m) = if start_date.month() == 12 { let (next_y, next_m) = if start_date.month() == 12 {
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1) (start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
} else { } else {
(start_date.year(), start_date.month() + 1) (start_date.year(), start_date.month() + 1)
}; };
candidate = get_cand_func(next_y, next_m)?; candidate = first_day_of_month(next_y, next_m)?;
} }
Ok(candidate) Ok(candidate)
} }
DateFreq::QuarterStart | DateFreq::QuarterEnd => { DateFreq::MonthEnd => {
let mut candidate = last_day_of_month(start_date.year(), start_date.month())?;
if candidate < start_date {
let (next_y, next_m) = if start_date.month() == 12 {
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
} else {
(start_date.year(), start_date.month() + 1)
};
candidate = last_day_of_month(next_y, next_m)?;
}
Ok(candidate)
}
DateFreq::QuarterStart => {
let current_q = month_to_quarter(start_date.month()); let current_q = month_to_quarter(start_date.month());
let get_cand_func = get_first_date_helper(freq); let mut candidate = first_day_of_quarter(start_date.year(), current_q)?;
let mut candidate = get_cand_func(start_date.year(), current_q)?;
if candidate < start_date { if candidate < start_date {
let (next_y, next_q) = if current_q == 4 { let (next_y, next_q) = if current_q == 4 {
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1) (start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
} else { } else {
(start_date.year(), current_q + 1) (start_date.year(), current_q + 1)
}; };
candidate = get_cand_func(next_y, next_q)?; candidate = first_day_of_quarter(next_y, next_q)?;
} }
Ok(candidate) Ok(candidate)
} }
DateFreq::QuarterEnd => {
DateFreq::YearStart | DateFreq::YearEnd => { let current_q = month_to_quarter(start_date.month());
let get_cand_func = get_first_date_helper(freq); let mut candidate = last_day_of_quarter(start_date.year(), current_q)?;
let mut candidate = get_cand_func(start_date.year(), 0)?; if candidate < start_date {
let (next_y, next_q) = if current_q == 4 {
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
} else {
(start_date.year(), current_q + 1)
};
candidate = last_day_of_quarter(next_y, next_q)?;
}
Ok(candidate)
}
DateFreq::YearStart => {
let mut candidate = first_day_of_year(start_date.year())?;
if candidate < start_date { if candidate < start_date {
candidate = candidate =
get_cand_func(start_date.year().checked_add(1).ok_or("Year overflow")?, 0)?; first_day_of_year(start_date.year().checked_add(1).ok_or("Year overflow")?)?;
}
Ok(candidate)
}
DateFreq::YearEnd => {
let mut candidate = last_day_of_year(start_date.year())?;
if candidate < start_date {
candidate =
last_day_of_year(start_date.year().checked_add(1).ok_or("Year overflow")?)?;
} }
Ok(candidate) Ok(candidate)
} }
@@ -965,10 +954,7 @@ pub fn find_first_date_on_or_after(
/// Finds the *next* valid date according to the frequency, /// Finds the *next* valid date according to the frequency,
/// given the `current_date` (which is assumed to be a valid date previously generated). /// given the `current_date` (which is assumed to be a valid date previously generated).
pub fn find_next_date( fn find_next_date(current_date: NaiveDate, freq: DateFreq) -> Result<NaiveDate, Box<dyn Error>> {
current_date: NaiveDate,
freq: DateFreq,
) -> Result<NaiveDate, Box<dyn Error>> {
match freq { match freq {
DateFreq::Daily => current_date DateFreq::Daily => current_date
.succ_opt() .succ_opt()
@@ -976,8 +962,7 @@ pub fn find_next_date(
DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => current_date DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => current_date
.checked_add_signed(Duration::days(7)) .checked_add_signed(Duration::days(7))
.ok_or_else(|| "Date overflow adding 7 days".into()), .ok_or_else(|| "Date overflow adding 7 days".into()),
DateFreq::MonthStart | DateFreq::MonthEnd => { DateFreq::MonthStart => {
let get_cand_func = get_first_date_helper(freq);
let (next_y, next_m) = if current_date.month() == 12 { let (next_y, next_m) = if current_date.month() == 12 {
( (
current_date.year().checked_add(1).ok_or("Year overflow")?, current_date.year().checked_add(1).ok_or("Year overflow")?,
@@ -986,11 +971,21 @@ pub fn find_next_date(
} else { } else {
(current_date.year(), current_date.month() + 1) (current_date.year(), current_date.month() + 1)
}; };
get_cand_func(next_y, next_m) first_day_of_month(next_y, next_m)
} }
DateFreq::QuarterStart | DateFreq::QuarterEnd => { DateFreq::MonthEnd => {
let (next_y, next_m) = if current_date.month() == 12 {
(
current_date.year().checked_add(1).ok_or("Year overflow")?,
1,
)
} else {
(current_date.year(), current_date.month() + 1)
};
last_day_of_month(next_y, next_m)
}
DateFreq::QuarterStart => {
let current_q = month_to_quarter(current_date.month()); let current_q = month_to_quarter(current_date.month());
let get_cand_func = get_first_date_helper(freq);
let (next_y, next_q) = if current_q == 4 { let (next_y, next_q) = if current_q == 4 {
( (
current_date.year().checked_add(1).ok_or("Year overflow")?, current_date.year().checked_add(1).ok_or("Year overflow")?,
@@ -999,14 +994,25 @@ pub fn find_next_date(
} else { } else {
(current_date.year(), current_q + 1) (current_date.year(), current_q + 1)
}; };
get_cand_func(next_y, next_q) first_day_of_quarter(next_y, next_q)
} }
DateFreq::YearStart | DateFreq::YearEnd => { DateFreq::QuarterEnd => {
let get_cand_func = get_first_date_helper(freq); let current_q = month_to_quarter(current_date.month());
get_cand_func( let (next_y, next_q) = if current_q == 4 {
(
current_date.year().checked_add(1).ok_or("Year overflow")?, current_date.year().checked_add(1).ok_or("Year overflow")?,
0, 1,
) )
} else {
(current_date.year(), current_q + 1)
};
last_day_of_quarter(next_y, next_q)
}
DateFreq::YearStart => {
first_day_of_year(current_date.year().checked_add(1).ok_or("Year overflow")?)
}
DateFreq::YearEnd => {
last_day_of_year(current_date.year().checked_add(1).ok_or("Year overflow")?)
} }
} }
} }
@@ -1469,6 +1475,8 @@ mod tests {
Ok(()) Ok(())
} }
// --- Tests for internal helper functions ---
#[test] #[test]
fn test_move_to_day_of_week_on_or_after() -> Result<(), Box<dyn Error>> { fn test_move_to_day_of_week_on_or_after() -> Result<(), Box<dyn Error>> {
assert_eq!( assert_eq!(
@@ -1500,8 +1508,7 @@ mod tests {
// And trying to move *past* it should fail // And trying to move *past* it should fail
let day_before = NaiveDate::MAX - Duration::days(1); let day_before = NaiveDate::MAX - Duration::days(1);
let target_day_after = NaiveDate::MAX.weekday().succ(); // Day after MAX's weekday let target_day_after = NaiveDate::MAX.weekday().succ(); // Day after MAX's weekday
assert!(move_to_day_of_week_on_or_after(day_before, target_day_after).is_err()); assert!(move_to_day_of_week_on_or_after(day_before, target_day_after).is_err()); // Moving past MAX fails
// Moving past MAX fails
} }
Ok(()) Ok(())
@@ -1520,15 +1527,12 @@ mod tests {
fn test_days_in_month() -> Result<(), Box<dyn Error>> { fn test_days_in_month() -> Result<(), Box<dyn Error>> {
assert_eq!(days_in_month(2023, 1)?, 31); assert_eq!(days_in_month(2023, 1)?, 31);
assert_eq!(days_in_month(2023, 2)?, 28); assert_eq!(days_in_month(2023, 2)?, 28);
// Leap assert_eq!(days_in_month(2024, 2)?, 29); // Leap
assert_eq!(days_in_month(2024, 2)?, 29);
assert_eq!(days_in_month(2023, 4)?, 30); assert_eq!(days_in_month(2023, 4)?, 30);
assert_eq!(days_in_month(2023, 12)?, 31); assert_eq!(days_in_month(2023, 12)?, 31);
// Invalid month 0 assert!(days_in_month(2023, 0).is_err()); // Invalid month 0
assert!(days_in_month(2023, 0).is_err()); assert!(days_in_month(2023, 13).is_err()); // Invalid month 13
// Invalid month 13
// Test near max date year overflow - Use MAX.year() // Test near max date year overflow - Use MAX.year()
assert!(days_in_month(2023, 13).is_err());
assert!(days_in_month(NaiveDate::MAX.year(), 12).is_err()); assert!(days_in_month(NaiveDate::MAX.year(), 12).is_err());
Ok(()) Ok(())
} }
@@ -1538,12 +1542,9 @@ mod tests {
assert_eq!(last_day_of_month(2023, 11)?, date(2023, 11, 30)); assert_eq!(last_day_of_month(2023, 11)?, date(2023, 11, 30));
assert_eq!(last_day_of_month(2024, 2)?, date(2024, 2, 29)); // Leap assert_eq!(last_day_of_month(2024, 2)?, date(2024, 2, 29)); // Leap
assert_eq!(last_day_of_month(2023, 12)?, date(2023, 12, 31)); assert_eq!(last_day_of_month(2023, 12)?, date(2023, 12, 31));
// Invalid month 0 assert!(last_day_of_month(2023, 0).is_err()); // Invalid month 0
assert!(last_day_of_month(2023, 0).is_err()); assert!(last_day_of_month(2023, 13).is_err()); // Invalid month 13
// Invalid month 13
// Test near max date year overflow - use MAX.year() // Test near max date year overflow - use MAX.year()
assert!(last_day_of_month(2023, 13).is_err());
assert!(last_day_of_month(NaiveDate::MAX.year(), 12).is_err()); assert!(last_day_of_month(NaiveDate::MAX.year(), 12).is_err());
Ok(()) Ok(())
} }
@@ -1587,8 +1588,7 @@ mod tests {
assert_eq!(first_day_of_quarter(2023, 2)?, date(2023, 4, 1)); assert_eq!(first_day_of_quarter(2023, 2)?, date(2023, 4, 1));
assert_eq!(first_day_of_quarter(2023, 3)?, date(2023, 7, 1)); assert_eq!(first_day_of_quarter(2023, 3)?, date(2023, 7, 1));
assert_eq!(first_day_of_quarter(2023, 4)?, date(2023, 10, 1)); assert_eq!(first_day_of_quarter(2023, 4)?, date(2023, 10, 1));
// Invalid quarter assert!(first_day_of_quarter(2023, 5).is_err()); // Invalid quarter
assert!(first_day_of_quarter(2023, 5).is_err());
Ok(()) Ok(())
} }
@@ -1608,11 +1608,9 @@ mod tests {
assert_eq!(last_day_of_quarter(2023, 2)?, date(2023, 6, 30)); assert_eq!(last_day_of_quarter(2023, 2)?, date(2023, 6, 30));
assert_eq!(last_day_of_quarter(2023, 3)?, date(2023, 9, 30)); assert_eq!(last_day_of_quarter(2023, 3)?, date(2023, 9, 30));
assert_eq!(last_day_of_quarter(2023, 4)?, date(2023, 12, 31)); assert_eq!(last_day_of_quarter(2023, 4)?, date(2023, 12, 31));
// Leap year doesn't affect March end assert_eq!(last_day_of_quarter(2024, 1)?, date(2024, 3, 31)); // Leap year doesn't affect March end
assert_eq!(last_day_of_quarter(2024, 1)?, date(2024, 3, 31)); assert!(last_day_of_quarter(2023, 5).is_err()); // Invalid quarter
// Invalid quarter
// Test overflow propagation - use MAX.year() // Test overflow propagation - use MAX.year()
assert!(last_day_of_quarter(2023, 5).is_err());
assert!(last_day_of_quarter(NaiveDate::MAX.year(), 4).is_err()); assert!(last_day_of_quarter(NaiveDate::MAX.year(), 4).is_err());
Ok(()) Ok(())
} }
@@ -1629,13 +1627,16 @@ mod tests {
#[test] #[test]
fn test_last_day_of_year() -> Result<(), Box<dyn Error>> { fn test_last_day_of_year() -> Result<(), Box<dyn Error>> {
assert_eq!(last_day_of_year(2023)?, date(2023, 12, 31)); assert_eq!(last_day_of_year(2023)?, date(2023, 12, 31));
// Leap year doesn't affect Dec 31st existence assert_eq!(last_day_of_year(2024)?, date(2024, 12, 31)); // Leap year doesn't affect Dec 31st existence
// Test MAX year - should be okay since MAX is Dec 31 // Test MAX year - should be okay since MAX is Dec 31
assert_eq!(last_day_of_year(2024)?, date(2024, 12, 31));
assert_eq!(last_day_of_year(NaiveDate::MAX.year())?, NaiveDate::MAX); assert_eq!(last_day_of_year(NaiveDate::MAX.year())?, NaiveDate::MAX);
Ok(()) Ok(())
} }
// Overflow tests for collect_* removed as they were misleading
// --- Tests for Generator Helper Functions ---
#[test] #[test]
fn test_find_first_date_on_or_after() -> Result<(), Box<dyn Error>> { fn test_find_first_date_on_or_after() -> Result<(), Box<dyn Error>> {
// Daily // Daily
@@ -1643,11 +1644,10 @@ mod tests {
find_first_date_on_or_after(date(2023, 11, 8), DateFreq::Daily)?, find_first_date_on_or_after(date(2023, 11, 8), DateFreq::Daily)?,
date(2023, 11, 8) date(2023, 11, 8)
); );
// Sat -> Sat
assert_eq!( assert_eq!(
find_first_date_on_or_after(date(2023, 11, 11), DateFreq::Daily)?, find_first_date_on_or_after(date(2023, 11, 11), DateFreq::Daily)?,
date(2023, 11, 11) date(2023, 11, 11)
); ); // Sat -> Sat
// Weekly Mon // Weekly Mon
assert_eq!( assert_eq!(
@@ -1658,11 +1658,10 @@ mod tests {
find_first_date_on_or_after(date(2023, 11, 13), DateFreq::WeeklyMonday)?, find_first_date_on_or_after(date(2023, 11, 13), DateFreq::WeeklyMonday)?,
date(2023, 11, 13) date(2023, 11, 13)
); );
// Sun -> Mon
assert_eq!( assert_eq!(
find_first_date_on_or_after(date(2023, 11, 12), DateFreq::WeeklyMonday)?, find_first_date_on_or_after(date(2023, 11, 12), DateFreq::WeeklyMonday)?,
date(2023, 11, 13) date(2023, 11, 13)
); ); // Sun -> Mon
// Weekly Fri // Weekly Fri
assert_eq!( assert_eq!(
@@ -1691,11 +1690,10 @@ mod tests {
find_first_date_on_or_after(date(2023, 12, 15), DateFreq::MonthStart)?, find_first_date_on_or_after(date(2023, 12, 15), DateFreq::MonthStart)?,
date(2024, 1, 1) date(2024, 1, 1)
); );
// Oct 1 -> Oct 1
assert_eq!( assert_eq!(
find_first_date_on_or_after(date(2023, 10, 1), DateFreq::MonthStart)?, find_first_date_on_or_after(date(2023, 10, 1), DateFreq::MonthStart)?,
date(2023, 10, 1) date(2023, 10, 1)
); ); // Oct 1 -> Oct 1
// Month End // Month End
assert_eq!( assert_eq!(
@@ -1706,21 +1704,18 @@ mod tests {
find_first_date_on_or_after(date(2023, 11, 15), DateFreq::MonthEnd)?, find_first_date_on_or_after(date(2023, 11, 15), DateFreq::MonthEnd)?,
date(2023, 11, 30) date(2023, 11, 30)
); );
// Dec 31 -> Dec 31
assert_eq!( assert_eq!(
find_first_date_on_or_after(date(2023, 12, 31), DateFreq::MonthEnd)?, find_first_date_on_or_after(date(2023, 12, 31), DateFreq::MonthEnd)?,
date(2023, 12, 31) date(2023, 12, 31)
); ); // Dec 31 -> Dec 31
// Mid Feb (Leap) -> Feb 29
assert_eq!( assert_eq!(
find_first_date_on_or_after(date(2024, 2, 15), DateFreq::MonthEnd)?, find_first_date_on_or_after(date(2024, 2, 15), DateFreq::MonthEnd)?,
date(2024, 2, 29) date(2024, 2, 29)
); ); // Mid Feb (Leap) -> Feb 29
// Feb 29 -> Feb 29
assert_eq!( assert_eq!(
find_first_date_on_or_after(date(2024, 2, 29), DateFreq::MonthEnd)?, find_first_date_on_or_after(date(2024, 2, 29), DateFreq::MonthEnd)?,
date(2024, 2, 29) date(2024, 2, 29)
); ); // Feb 29 -> Feb 29
// Quarter Start // Quarter Start
assert_eq!( assert_eq!(
@@ -1925,11 +1920,13 @@ mod tests {
assert!(find_next_date(NaiveDate::MAX, DateFreq::MonthEnd).is_err()); assert!(find_next_date(NaiveDate::MAX, DateFreq::MonthEnd).is_err());
// Test finding next quarter start after Q4 MAX_YEAR -> Q1 (MAX_YEAR+1) (fail) // Test finding next quarter start after Q4 MAX_YEAR -> Q1 (MAX_YEAR+1) (fail)
assert!(find_next_date( assert!(
find_next_date(
first_day_of_quarter(NaiveDate::MAX.year(), 4)?, first_day_of_quarter(NaiveDate::MAX.year(), 4)?,
DateFreq::QuarterStart DateFreq::QuarterStart
) )
.is_err()); .is_err()
);
// Test finding next quarter end after Q3 MAX_YEAR -> Q4 MAX_YEAR (fails because last_day_of_quarter(MAX, 4) fails) // Test finding next quarter end after Q3 MAX_YEAR -> Q4 MAX_YEAR (fails because last_day_of_quarter(MAX, 4) fails)
let q3_end_max_year = last_day_of_quarter(NaiveDate::MAX.year(), 3)?; let q3_end_max_year = last_day_of_quarter(NaiveDate::MAX.year(), 3)?;
@@ -1940,18 +1937,22 @@ mod tests {
assert!(find_next_date(NaiveDate::MAX, DateFreq::QuarterEnd).is_err()); assert!(find_next_date(NaiveDate::MAX, DateFreq::QuarterEnd).is_err());
// Test finding next year start after Jan 1 MAX_YEAR -> Jan 1 (MAX_YEAR+1) (fail) // Test finding next year start after Jan 1 MAX_YEAR -> Jan 1 (MAX_YEAR+1) (fail)
assert!(find_next_date( assert!(
find_next_date(
first_day_of_year(NaiveDate::MAX.year())?, first_day_of_year(NaiveDate::MAX.year())?,
DateFreq::YearStart DateFreq::YearStart
) )
.is_err()); .is_err()
);
// Test finding next year end after Dec 31 (MAX_YEAR-1) -> Dec 31 MAX_YEAR (ok) // Test finding next year end after Dec 31 (MAX_YEAR-1) -> Dec 31 MAX_YEAR (ok)
assert!(find_next_date( assert!(
find_next_date(
last_day_of_year(NaiveDate::MAX.year() - 1)?, last_day_of_year(NaiveDate::MAX.year() - 1)?,
DateFreq::YearEnd DateFreq::YearEnd
) )
.is_ok()); .is_ok()
);
// Test finding next year end after Dec 31 MAX_YEAR -> Dec 31 (MAX_YEAR+1) (fail) // Test finding next year end after Dec 31 MAX_YEAR -> Dec 31 (MAX_YEAR+1) (fail)
assert!( assert!(
@@ -2151,15 +2152,32 @@ mod tests {
// find_first returns start_date (YE MAX-1) // find_first returns start_date (YE MAX-1)
assert_eq!(generator.next(), Some(start_date)); assert_eq!(generator.next(), Some(start_date));
// find_next finds YE(MAX) // find_next finds YE(MAX)
assert_eq!(generator.next(), Some(last_day_of_year(start_year)?)); assert_eq!(generator.next(), Some(last_day_of_year(start_year)?)); // Should be MAX
// Should be MAX
// find_next tries YE(MAX+1) - this call to find_next_date fails internally // find_next tries YE(MAX+1) - this call to find_next_date fails internally
assert_eq!(generator.next(), None); assert_eq!(generator.next(), None); // Returns None because internal find_next_date failed
// Returns None because internal find_next_date failed
// Check internal state after the call that returned None
// When Some(YE MAX) was returned, periods_remaining became 1.
// The next call enters the match, calls find_next_date (fails -> .ok() is None),
// sets next_date_candidate=None, decrements periods_remaining to 0, returns Some(YE MAX).
// --> NO, the code was: set candidate=find().ok(), THEN decrement.
// Let's revisit Iterator::next logic:
// 1. periods_remaining = 1, next_date_candidate = Some(YE MAX)
// 2. Enter match arm
// 3. find_next_date(YE MAX, YE) -> Err
// 4. self.next_date_candidate = Err.ok() -> None
// 5. self.periods_remaining -= 1 -> becomes 0
// 6. return Some(YE MAX) <-- This was the bug in my reasoning. It returns the *current* date first.
// State after returning Some(YE MAX): periods_remaining = 0, next_date_candidate = None
// Next call to generator.next():
// 1. periods_remaining = 0
// 2. Enter the `_` arm of the match
// 3. self.periods_remaining = 0 (no change)
// 4. self.next_date_candidate = None (no change)
// 5. return None
// State after the *first* None is returned: // State after the *first* None is returned:
// Corrected assertion assert_eq!(generator.periods_remaining, 0); // Corrected assertion
assert_eq!(generator.periods_remaining, 0);
assert!(generator.next_date_candidate.is_none()); assert!(generator.next_date_candidate.is_none());
// Calling next() again should also return None // Calling next() again should also return None

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +0,0 @@
pub mod bdates;
pub mod dates;
pub use bdates::{BDateFreq, BDatesGenerator, BDatesList};
pub use dates::{DateFreq, DatesGenerator, DatesList};

View File

@@ -1,4 +1,6 @@
pub mod dateutils; pub mod bdates;
pub use bdates::{BDateFreq, BDatesList, BDatesGenerator};
pub mod dates;
pub use dates::{DateFreq, DatesList, DatesGenerator};
pub use dateutils::{BDateFreq, BDatesGenerator, BDatesList};
pub use dateutils::{DateFreq, DatesGenerator, DatesList};