mirror of
https://github.com/Magnus167/rustframe.git
synced 2025-11-19 16:56:09 +00:00
Compare commits
377 Commits
e61494edc1
...
v0.0.1-a.2
| Author | SHA1 | Date | |
|---|---|---|---|
| 109d39b248 | |||
|
|
18ad6c689a | ||
| 1fead78b69 | |||
|
|
6fb32e743c | ||
| 2cb4e46217 | |||
|
|
a53ba63f30 | ||
|
|
dae60ea1bd | ||
|
|
755dee58e7 | ||
|
|
9e6e22fc37 | ||
|
|
b687fd4e6b | ||
|
|
68a01ab528 | ||
|
|
23a01dab07 | ||
|
|
f4ebd78234 | ||
|
|
1475156855 | ||
|
|
080680d095 | ||
|
|
2845f357b7 | ||
|
|
3d11226d57 | ||
|
|
039fb1a98e | ||
|
|
31a5ba2460 | ||
|
|
1a9f397702 | ||
|
|
ecd06eb352 | ||
|
|
ae327b6060 | ||
|
|
83ac9d4821 | ||
|
|
ae27ed9373 | ||
|
|
c7552f2264 | ||
|
|
3654c7053c | ||
|
|
1dcd9727b4 | ||
|
|
b62152b4f0 | ||
|
|
a6a901d6ab | ||
|
|
676af850ef | ||
|
|
ca2ca2a738 | ||
|
|
4876a74e01 | ||
|
|
b78dd75e77 | ||
|
|
9db8853d75 | ||
|
|
9738154dac | ||
| 7d0978e5fb | |||
|
|
ed01c4b8f2 | ||
|
|
e6964795e3 | ||
|
|
d1dd7ea6d2 | ||
|
|
676f78bb1e | ||
|
|
f7325a9558 | ||
|
|
18b9eef063 | ||
|
|
f99f78d508 | ||
| 2926a8a6e8 | |||
| d851c500af | |||
|
|
d741c7f472 | ||
|
|
7720312354 | ||
|
|
5509416d5f | ||
|
|
a451ba8cc7 | ||
|
|
bce1bdd21a | ||
| af70f9ffd7 | |||
|
|
7f33223496 | ||
|
|
73dbb25242 | ||
| 4061ebf8ae | |||
|
|
ef322fc6a2 | ||
|
|
750adc72e9 | ||
|
|
3207254564 | ||
|
|
2ea83727a1 | ||
|
|
3f56b378b2 | ||
|
|
afcb29e716 | ||
|
|
113831dc8c | ||
|
|
289c70d9e9 | ||
|
|
cd13d98110 | ||
|
|
b4520b0d30 | ||
|
|
5934b163f5 | ||
|
|
4a1843183a | ||
|
|
252c8a3d29 | ||
|
|
5a5baf9716 | ||
|
|
28793e5b07 | ||
|
|
d75bd7a08f | ||
|
|
6fd796cceb | ||
|
|
d0b0f295b1 | ||
| 556b08216f | |||
|
|
17201b4d29 | ||
|
|
2a99d8930c | ||
|
|
38213c73c7 | ||
|
|
c004bd8334 | ||
|
|
dccbba9d1b | ||
|
|
ab3509fef4 | ||
| f5c56d02e2 | |||
| 069ef25ef4 | |||
|
|
f9a60608df | ||
| 526e22b1b7 | |||
|
|
845667c60a | ||
|
|
3935e80be6 | ||
|
|
0ce970308b | ||
|
|
72d02e2336 | ||
|
|
26213b28d6 | ||
|
|
44ff16a0bb | ||
|
|
1192a78955 | ||
|
|
d0f9e80dfc | ||
|
|
b0d8050b11 | ||
|
|
45ec754d47 | ||
|
|
733a4da383 | ||
|
|
ded5f1aa29 | ||
|
|
fe9498963d | ||
|
|
6b580ec5eb | ||
|
|
45f147e651 | ||
| 6abf4ec983 | |||
|
|
037cfd9113 | ||
|
|
74fac9d512 | ||
| 27e9eab028 | |||
|
|
c13fcc99f7 | ||
|
|
eb9de0a647 | ||
|
|
044c3284df | ||
|
|
ad4cadd8fb | ||
| 34b09508f3 | |||
|
|
a8a532f252 | ||
|
|
19c3dde169 | ||
|
|
a335d29347 | ||
|
|
b2f6794e05 | ||
|
|
5f1f0970da | ||
|
|
7bbfb5394f | ||
|
|
285147d52b | ||
|
|
64722914bd | ||
|
|
86ea548b4f | ||
|
|
1bdcf1b113 | ||
|
|
7c7c8c2a16 | ||
|
|
4d8ed2e908 | ||
|
|
62d4803075 | ||
|
|
19bc09fd5a | ||
|
|
bda9b84987 | ||
|
|
c24eb4a08c | ||
|
|
12a72317e4 | ||
|
|
049dd02c1a | ||
|
|
bc87e40481 | ||
|
|
eebe772da6 | ||
|
|
7b0d34384a | ||
|
|
9182ab9fca | ||
|
|
de18d8e010 | ||
|
|
9b08eaeb35 | ||
|
|
a3bb509202 | ||
|
|
10018f7efe | ||
|
|
b7480b20d4 | ||
|
|
d5afb4e87a | ||
|
|
493eb96a05 | ||
|
|
58b0a5f0d9 | ||
|
|
37c0d312e5 | ||
|
|
e7c181f011 | ||
|
|
2cd2e24f57 | ||
|
|
61aeedbf76 | ||
|
|
8ffa278db8 | ||
|
|
b2a799fc30 | ||
|
|
5779c6b82d | ||
|
|
a2fcaf1d52 | ||
|
|
6711cad6e2 | ||
|
|
46cfe43983 | ||
|
|
122a972a33 | ||
|
|
2a63e6d5ab | ||
|
|
e48ce7d6d7 | ||
|
|
a08fb546a9 | ||
|
|
e195481691 | ||
|
|
87d14bbf5f | ||
|
|
4f8a27298c | ||
|
|
4648800a09 | ||
|
|
96f434bf94 | ||
|
|
46abeb12a7 | ||
|
|
75d07371b2 | ||
|
|
70d2a7a2b4 | ||
|
|
261d0d7007 | ||
|
|
005c10e816 | ||
|
|
4c626bf09c | ||
|
|
ab6d5f9f8f | ||
|
|
1c8fcc0bad | ||
|
|
2ca496cfd1 | ||
|
|
85154a3be0 | ||
|
|
54a266b630 | ||
|
|
4ddacdfd21 | ||
|
|
37b20f2174 | ||
|
|
b279131503 | ||
|
|
eb948c1f49 | ||
|
|
d4c0f174b1 | ||
|
|
b6645fcfbd | ||
|
|
b1b7e63fea | ||
|
|
e2c5e65c18 | ||
|
|
be41e9b20e | ||
|
|
1501ed5b7a | ||
|
|
dbbf5f9617 | ||
|
|
6718cf5de7 | ||
|
|
f749b2c921 | ||
|
|
04637ef4d0 | ||
| 11330e464b | |||
|
|
162a09fc22 | ||
| 5db5475a61 | |||
|
|
2da1e9bf04 | ||
|
|
601dc66d7d | ||
|
|
073a22b866 | ||
|
|
ffa1a76df4 | ||
|
|
621632b7d2 | ||
|
|
64e578fae2 | ||
| f44bb5b205 | |||
|
|
fde9c73a66 | ||
|
|
85e0eb7e67 | ||
| b8e64811ed | |||
|
|
2ac2db258f | ||
|
|
7a68d13eb3 | ||
|
|
f39c678192 | ||
|
|
ad9f89860e | ||
|
|
ef574892fa | ||
|
|
30bff6ecf4 | ||
|
|
9daf583a4d | ||
|
|
1ebc3546d2 | ||
|
|
ffe635f1c4 | ||
|
|
75f194b8c9 | ||
|
|
3e279b8138 | ||
|
|
a06725686f | ||
|
|
7106f6ccf7 | ||
|
|
95911fb0ef | ||
|
|
84b7fc8eec | ||
| 7a580327e8 | |||
|
|
1e908b0ac5 | ||
|
|
1f306c59f2 | ||
|
|
48660b1a75 | ||
|
|
6858fa4bfa | ||
| 545c1d7117 | |||
|
|
da72f18b42 | ||
| e7d4e2221c | |||
| c7e39e6f99 | |||
| 36588e38f6 | |||
|
|
bc5cc65897 | ||
|
|
2477b3e8f6 | ||
|
|
60639ade01 | ||
| 6e0ea441e4 | |||
|
|
453e34ef82 | ||
|
|
092a7b7cce | ||
| 349ae52629 | |||
|
|
7b33b5b582 | ||
|
|
9279939c30 | ||
| 092aebd5fc | |||
|
|
f1ab33ed96 | ||
|
|
80196f8a53 | ||
| afe0e15ede | |||
| 91b4ed9c56 | |||
|
|
bc00035bc0 | ||
|
|
4e98aa1490 | ||
|
|
8f3bc6843b | ||
|
|
127e178b79 | ||
|
|
52729e55be | ||
|
|
c37659f09d | ||
|
|
629c9d84e2 | ||
|
|
f5fd935475 | ||
| 84170d7ae7 | |||
|
|
86fec0a6b1 | ||
|
|
63e5f4b3a4 | ||
|
|
df28440dfd | ||
|
|
f4d52f0ce8 | ||
|
|
e5d8f4386c | ||
|
|
d8239114cf | ||
|
|
c1834c0141 | ||
|
|
fb63b3fd3d | ||
|
|
b1a2a2260d | ||
| 77983eef77 | |||
| e9eeac0c40 | |||
|
|
85482d9569 | ||
|
|
c05f1696f0 | ||
|
|
e3f4749709 | ||
|
|
643c897479 | ||
|
|
876f1ccbf3 | ||
|
|
9702b6d5c4 | ||
|
|
6a9b828ada | ||
|
|
1a5b8919d3 | ||
|
|
2e980a78fa | ||
|
|
dfe259a371 | ||
|
|
4e74c2dcfe | ||
|
|
498f822672 | ||
|
|
34809656f6 | ||
|
|
894b85b384 | ||
|
|
b758b22b93 | ||
|
|
bb0bffba73 | ||
|
|
bca1121004 | ||
|
|
b3b0e5e3ae | ||
|
|
eeabfcfff6 | ||
|
|
eab1c5eec1 | ||
|
|
efe44c7399 | ||
|
|
33fea1d126 | ||
|
|
f0de677b69 | ||
|
|
bbcdbb4151 | ||
|
|
aec6278a50 | ||
|
|
054b3c828e | ||
|
|
659e93c27d | ||
|
|
f520f29f11 | ||
|
|
db8b756a74 | ||
|
|
d8154a1175 | ||
|
|
4d4b6d1656 | ||
| 2bc375f950 | |||
|
|
37d10cbc7d | ||
|
|
1c41d387ef | ||
|
|
38baf0c648 | ||
|
|
2e6c4bd6bb | ||
|
|
6bb1c2a0de | ||
|
|
d670ab4a5c | ||
|
|
eb09593b0a | ||
|
|
29c304d512 | ||
|
|
7c96439550 | ||
|
|
76b8824ce3 | ||
|
|
00befe7ee4 | ||
|
|
41349e2dba | ||
|
|
9ab3a7c2c1 | ||
|
|
88a08e1063 | ||
|
|
9696cd47bd | ||
|
|
1ad25d7e0b | ||
|
|
00889754ea | ||
| 4a5485cddd | |||
| ce98d74836 | |||
|
|
de549d15d5 | ||
| 97e8f3d55c | |||
| bccde27fb8 | |||
|
|
345166b16c | ||
|
|
5e03e7f1e5 | ||
|
|
d42437bf32 | ||
|
|
eeba602c14 | ||
|
|
7e27539529 | ||
|
|
80083ce77c | ||
|
|
179553fe38 | ||
|
|
eb21f21dcb | ||
|
|
5f1bb68723 | ||
|
|
ac4d67146e | ||
|
|
5cca2ca149 | ||
|
|
186dffe29b | ||
|
|
88d1d6b3ff | ||
|
|
6d79797266 | ||
|
|
a62c24304e | ||
|
|
3f11068f9c | ||
|
|
18887e72ae | ||
|
|
02c00cc657 | ||
|
|
e74617f2e4 | ||
|
|
63f1a83c2d | ||
|
|
96775d28ee | ||
| 88b608ea1b | |||
| ddcc539076 | |||
|
|
3795602e4b | ||
| dcc3d6eb6c | |||
|
|
bd42e2b2c9 | ||
|
|
b4069fa18b | ||
|
|
7a3ddc9147 | ||
| 80779ac6e5 | |||
| efbe769d4a | |||
|
|
204cee0f47 | ||
|
|
d7385f789b | ||
|
|
0e7c18656c | ||
| d42da5ad7b | |||
| fc508bb1a9 | |||
|
|
c7f3cc6ef9 | ||
|
|
eb28b93781 | ||
|
|
33d021072e | ||
|
|
8ba39f8e87 | ||
|
|
cae230783e | ||
|
|
101ee8c1c7 | ||
|
|
b25c6d358e | ||
|
|
aaac6bfc3b | ||
|
|
6076e0c8e3 | ||
|
|
d15aaf437f | ||
| 3b67d6e477 | |||
|
|
f084eae72c | ||
|
|
860cd4d081 | ||
| 04fbf5ca50 | |||
|
|
49c9e7f66d | ||
|
|
08a4871907 | ||
|
|
dbb95c9b1a | ||
|
|
8417b9e5d7 | ||
| d81da1786d | |||
|
|
8876379443 | ||
|
|
4fe81b7a26 | ||
|
|
aca1e1381f | ||
|
|
6d0430e5ed | ||
|
|
ae2daeeda7 | ||
| 4cb095b71f | |||
|
|
a59c4c56bf | ||
|
|
8cb33bc096 | ||
|
|
bb34991a5f | ||
|
|
676ae413a6 | ||
|
|
4e70e868fd | ||
| e2db5eb315 | |||
| bdef7f1732 | |||
| 2b4ef8a371 | |||
|
|
1213d588ec | ||
|
|
a76963ec2e |
73
.github/actions/runner-fallback/action.yml
vendored
Normal file
73
.github/actions/runner-fallback/action.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
# actions/runner-fallback/action.yml
#
# Composite action: emits a `runs-on`-ready JSON label array, preferring a
# self-hosted runner when one matching all primary labels is online.
name: "Runner Fallback"
description: |
  Chooses a self-hosted runner when one with the required labels is online,
  otherwise returns a fallback GitHub-hosted label.
inputs:
  primary-runner:
    description: 'Comma-separated label list for the preferred self-hosted runner (e.g. "self-hosted,linux")'
    required: true
  fallback-runner:
    description: 'Comma-separated label list or single label for the fallback (e.g. "ubuntu-latest")'
    required: true
  github-token:
    description: 'GitHub token with repo admin read permissions'
    required: true

outputs:
  use-runner:
    description: "JSON array of labels you can feed straight into runs-on"
    value: ${{ steps.pick.outputs.use-runner }}

runs:
  using: "composite"
  steps:
    - name: Check self-hosted fleet
      id: pick
      shell: bash
      env:
        TOKEN: ${{ inputs.github-token }}
        PRIMARY: ${{ inputs.primary-runner }}
        FALLBACK: ${{ inputs.fallback-runner }}
      run: |
        # -------- helper -----------
        # Render a comma-separated label list as a strict JSON array.
        # Fix: the previous version printed a trailing comma ("[\"a\",]"),
        # which strict JSON parsers reject; strip it before closing.
        to_json_array () {
          local list="$1"; IFS=',' read -ra L <<<"$list"
          local body
          body=$(printf '"%s",' "${L[@]}")
          printf '[%s]' "${body%,}"
        }
        # -------- query API ---------
        repo="${{ github.repository }}"
        runners=$(curl -s -H "Authorization: Bearer $TOKEN" \
          -H "Accept: application/vnd.github+json" \
          "https://api.github.com/repos/$repo/actions/runners?per_page=100")

        # Debug: Print runners content
        # echo "Runners response: $runners"

        # Check if runners is null or empty
        if [ -z "$runners" ] || [ "$runners" = "null" ]; then
          echo "❌ Error: Unable to fetch runners or no runners found." >&2
          exit 1
        fi

        # A runner qualifies only when EVERY wanted label is present and the
        # runner reports status "online".
        IFS=',' read -ra WANT <<<"$PRIMARY"
        online_found=0
        while read -r row; do
          labels=$(jq -r '.labels[].name' <<<"$row")
          ok=1
          for w in "${WANT[@]}"; do
            grep -Fxq "$w" <<<"$labels" || { ok=0; break; }
          done
          [ "$ok" -eq 1 ] && { online_found=1; break; }
        done < <(jq -c '.runners[] | select(.status=="online")' <<<"$runners")

        if [ "$online_found" -eq 1 ]; then
          echo "✅ Self-hosted runner online."
          echo "use-runner=$(to_json_array "$PRIMARY")" >>"$GITHUB_OUTPUT"
        else
          echo "❌ No matching self-hosted runner online - using fallback."
          echo "use-runner=$(to_json_array "$FALLBACK")" >>"$GITHUB_OUTPUT"
        fi
|
||||
81
.github/htmldocs/index.html
vendored
Normal file
81
.github/htmldocs/index.html
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
<!DOCTYPE html>
<!-- Landing page for the rustframe GitHub Pages site: logo plus links to the
     repo, user guide, API docs, benchmark report and package registries. -->
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Rustframe</title>
    <link rel="icon" type="image/png" href="./rustframe_logo.png">
    <style>
        /* Dark theme; card centered both ways via flexbox on body. */
        body {
            font-family: Arial, sans-serif;
            background-color: #2b2b2b;
            color: #d4d4d4;
            margin: 0;
            padding: 0;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
        }

        main {
            text-align: center;
            padding: 20px;
            background-color: #3c3c3c;
            border-radius: 10px;
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
            max-width: 600px;
        }

        img {
            max-width: 100px;
            margin-bottom: 20px;
        }

        h1 {
            /* logo is b35f20 */
            color: #f8813f;
        }

        a {
            color: #ff9a60;
            text-decoration: none;
        }

        a:hover {
            text-decoration: underline;
        }
    </style>
</head>

<body>
    <main>
        <h1>
            <img src="./rustframe_logo.png" alt="Rustframe Logo"><br>
            Rustframe
        </h1>
        <h2>A lightweight dataframe & math toolkit for Rust</h2>
        <hr style="border: 1px solid #d4d4d4; margin: 20px 0;">
        <p>

            🐙 <a href="https://github.com/Magnus167/rustframe">GitHub</a>
            <br><br>

            📖 <a href="https://magnus167.github.io/rustframe/user-guide">User Guide</a>
            <br><br>


            📚 <a href="https://magnus167.github.io/rustframe/docs">Docs</a> |
            📊 <a href="https://magnus167.github.io/rustframe/benchmark-report/">Benchmarks</a>

            <br><br>
            🦀 <a href="https://crates.io/crates/rustframe">Crates.io</a> |
            🔖 <a href="https://docs.rs/rustframe/latest/rustframe/">docs.rs</a>
            <br><br>
            <!-- 🌐 <a href="https://gitea.nulltech.uk/Magnus167/rustframe">Gitea mirror</a> -->
        </p>
    </main>
</body>

</html>
|
||||
30
.github/runners/runner-arm/Dockerfile
vendored
Normal file
30
.github/runners/runner-arm/Dockerfile
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
# Self-hosted GitHub Actions runner image (linux/arm64).
# Installs the runner tarball under an unprivileged "docker" user and hands
# control to entrypoint.sh, which registers/deregisters the runner.
FROM ubuntu:latest

ARG RUNNER_VERSION="2.323.0"

# Prevents installdependencies.sh from prompting the user and blocking the image creation
ARG DEBIAN_FRONTEND=noninteractive

# Base tooling plus the unprivileged user the runner will execute as.
RUN apt update -y && apt upgrade -y && useradd -m docker
RUN apt install -y --no-install-recommends \
    curl jq build-essential libssl-dev libffi-dev python3 python3-venv python3-dev python3-pip \
    # dot net core dependencies
    libicu74 libssl3 libkrb5-3 zlib1g libcurl4


# Download and unpack the arm64 runner release.
RUN cd /home/docker && mkdir actions-runner && cd actions-runner \
    && curl -O -L https://github.com/actions/runner/releases/download/v${RUNNER_VERSION}/actions-runner-linux-arm64-${RUNNER_VERSION}.tar.gz \
    && tar xzf ./actions-runner-linux-arm64-${RUNNER_VERSION}.tar.gz

RUN chown -R docker ~docker && /home/docker/actions-runner/bin/installdependencies.sh

COPY entrypoint.sh entrypoint.sh

# make the script executable
RUN chmod +x entrypoint.sh

# since the config and run script for actions are not allowed to be run by root,
# set the user to "docker" so all subsequent commands are run as the docker user
USER docker

ENTRYPOINT ["./entrypoint.sh"]
|
||||
18
.github/runners/runner-arm/docker-compose.yml
vendored
Normal file
18
.github/runners/runner-arm/docker-compose.yml
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
# docker-compose.yml
#
# Builds and runs the arm64 self-hosted GitHub Actions runner container.
# Credentials and labels come from a local .env file (see example.env).

services:
  github-runner:
    build:
      context: .
      args:
        RUNNER_VERSION: 2.323.0
    # container_name commented to allow for multiple runners
    # container_name: github-runner
    env_file:
      - .env
    volumes:
      # NOTE(review): the Dockerfile installs the runner under /home/docker,
      # but this mounts /home/runner — confirm the intended work directory.
      - runner-work:/home/runner/actions-runner/_work
    restart: unless-stopped

volumes:
  runner-work:
|
||||
24
.github/runners/runner-arm/entrypoint.sh
vendored
Normal file
24
.github/runners/runner-arm/entrypoint.sh
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
# Registers this container as a GitHub Actions runner for $REPO, runs it,
# and deregisters it on SIGINT/SIGTERM.
#
# Required environment (see example.env):
#   REPO          - "owner/name" repository slug
#   GH_TOKEN      - token allowed to create runner registration tokens
#   RUNNER_LABELS - comma-separated labels to register with

REPOSITORY=$REPO
ACCESS_TOKEN=$GH_TOKEN
LABELS=$RUNNER_LABELS

# echo "REPO ${REPOSITORY}"
# echo "ACCESS_TOKEN ${ACCESS_TOKEN}"

# Fetch a short-lived registration token for this repository.
REG_TOKEN=$(curl -s -X POST -H "Authorization: token ${ACCESS_TOKEN}" -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${REPOSITORY}/actions/runners/registration-token" | jq .token --raw-output)

cd /home/docker/actions-runner || exit 1

./config.sh --url "https://github.com/${REPOSITORY}" --token "${REG_TOKEN}" --labels "${LABELS}"

cleanup() {
    echo "Removing runner..."
    # Fix: registration tokens expire (~1 hour), so reusing REG_TOKEN here
    # fails for long-lived runners — request a fresh remove-token instead.
    local REMOVE_TOKEN
    REMOVE_TOKEN=$(curl -s -X POST -H "Authorization: token ${ACCESS_TOKEN}" -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${REPOSITORY}/actions/runners/remove-token" | jq .token --raw-output)
    ./config.sh remove --unattended --token "${REMOVE_TOKEN}"
}

trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM

# Run in the background and wait, so the traps can fire while running.
./run.sh & wait $!
|
||||
9
.github/runners/runner-arm/example.env
vendored
Normal file
9
.github/runners/runner-arm/example.env
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
# Repository name
|
||||
REPO="Magnus167/rustframe"
|
||||
|
||||
# GitHub runner token
|
||||
GH_TOKEN="some_token_here"
|
||||
|
||||
# Labels for the runner
|
||||
RUNNER_LABELS=self-hosted-linux,linux
|
||||
4
.github/runners/runner-arm/start.sh
vendored
Normal file
4
.github/runners/runner-arm/start.sh
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
|
||||
#!/usr/bin/env bash
# Build the runner image and start it detached.
# Fix: added a shebang so the script runs under bash regardless of caller.
docker compose up -d --build
# docker compose up -d --build --scale github-runner=2
|
||||
45
.github/runners/runner-x64/Dockerfile
vendored
Normal file
45
.github/runners/runner-x64/Dockerfile
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
# Self-hosted GitHub Actions runner image (linux/x64).
# Same layout as the arm64 image, plus GitHub CLI and a Rust toolchain for
# building/testing the crate on the runner itself.
FROM ubuntu:latest

ARG RUNNER_VERSION="2.323.0"

# Prevents installdependencies.sh from prompting the user and blocking the image creation
ARG DEBIAN_FRONTEND=noninteractive

# Base tooling plus the unprivileged "docker" user the runner executes as.
RUN apt update -y && apt upgrade -y && useradd -m docker
RUN apt install -y --no-install-recommends \
    curl jq git zip unzip \
    # dev dependencies
    build-essential libssl-dev libffi-dev python3 python3-venv python3-dev python3-pip \
    # dot net core dependencies
    libicu74 libssl3 libkrb5-3 zlib1g libcurl4 \
    # Rust and Cargo dependencies
    gcc cmake

# Install GitHub CLI
RUN curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg \
    && chmod go+r /usr/share/keyrings/githubcli-archive-keyring.gpg \
    && echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
    && apt update -y && apt install -y gh \
    && rm -rf /var/lib/apt/lists/*

# Install Rust and Cargo
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/home/docker/.cargo/bin:${PATH}"
ENV HOME="/home/docker"

# Download and unpack the x64 runner release.
RUN cd /home/docker && mkdir actions-runner && cd actions-runner \
    && curl -O -L https://github.com/actions/runner/releases/download/v${RUNNER_VERSION}/actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz \
    && tar xzf ./actions-runner-linux-x64-${RUNNER_VERSION}.tar.gz

RUN chown -R docker ~docker && /home/docker/actions-runner/bin/installdependencies.sh

COPY entrypoint.sh entrypoint.sh

# make the script executable
RUN chmod +x entrypoint.sh

# since the config and run script for actions are not allowed to be run by root,
# set the user to "docker" so all subsequent commands are run as the docker user
USER docker

ENTRYPOINT ["./entrypoint.sh"]
|
||||
18
.github/runners/runner-x64/docker-compose.yml
vendored
Normal file
18
.github/runners/runner-x64/docker-compose.yml
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
# docker-compose.yml
#
# Builds and runs the x64 self-hosted GitHub Actions runner container.
# Credentials and labels come from a local .env file (see example.env).

services:
  github-runner:
    build:
      context: .
      args:
        RUNNER_VERSION: 2.323.0
    # container_name commented to allow for multiple runners
    # container_name: github-runner
    env_file:
      - .env
    volumes:
      # NOTE(review): the Dockerfile installs the runner under /home/docker,
      # but this mounts /home/runner — confirm the intended work directory.
      - runner-work:/home/runner/actions-runner/_work
    restart: unless-stopped

volumes:
  runner-work:
|
||||
24
.github/runners/runner-x64/entrypoint.sh
vendored
Normal file
24
.github/runners/runner-x64/entrypoint.sh
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
#!/bin/bash
# Registers this container as a GitHub Actions runner for $REPO, runs it,
# and deregisters it on SIGINT/SIGTERM.
#
# Required environment (see example.env):
#   REPO          - "owner/name" repository slug
#   GH_TOKEN      - token allowed to create runner registration tokens
#   RUNNER_LABELS - comma-separated labels to register with

REPOSITORY=$REPO
ACCESS_TOKEN=$GH_TOKEN
LABELS=$RUNNER_LABELS

# echo "REPO ${REPOSITORY}"
# echo "ACCESS_TOKEN ${ACCESS_TOKEN}"

# Fetch a short-lived registration token for this repository.
REG_TOKEN=$(curl -s -X POST -H "Authorization: token ${ACCESS_TOKEN}" -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${REPOSITORY}/actions/runners/registration-token" | jq .token --raw-output)

cd /home/docker/actions-runner || exit 1

./config.sh --url "https://github.com/${REPOSITORY}" --token "${REG_TOKEN}" --labels "${LABELS}"

cleanup() {
    echo "Removing runner..."
    # Fix: registration tokens expire (~1 hour), so reusing REG_TOKEN here
    # fails for long-lived runners — request a fresh remove-token instead.
    local REMOVE_TOKEN
    REMOVE_TOKEN=$(curl -s -X POST -H "Authorization: token ${ACCESS_TOKEN}" -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${REPOSITORY}/actions/runners/remove-token" | jq .token --raw-output)
    ./config.sh remove --unattended --token "${REMOVE_TOKEN}"
}

trap 'cleanup; exit 130' INT
trap 'cleanup; exit 143' TERM

# Run in the background and wait, so the traps can fire while running.
./run.sh & wait $!
|
||||
9
.github/runners/runner-x64/example.env
vendored
Normal file
9
.github/runners/runner-x64/example.env
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
# Repository name
|
||||
REPO="Magnus167/rustframe"
|
||||
|
||||
# GitHub runner token
|
||||
GH_TOKEN="some_token_here"
|
||||
|
||||
# Labels for the runner
|
||||
RUNNER_LABELS=self-hosted-linux,linux
|
||||
4
.github/runners/runner-x64/start.sh
vendored
Normal file
4
.github/runners/runner-x64/start.sh
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
|
||||
#!/usr/bin/env bash
# Build the runner image and start it detached.
# Fix: added a shebang so the script runs under bash regardless of caller.
docker compose up -d --build
# docker compose up -d --build --scale github-runner=2
|
||||
426
.github/scripts/custom_benchmark_report.py
vendored
Normal file
426
.github/scripts/custom_benchmark_report.py
vendored
Normal file
@@ -0,0 +1,426 @@
|
||||
# create_benchmark_table.py
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import pandas as pd
|
||||
import html # Import the html module for escaping
|
||||
|
||||
|
||||
# Regular expression to parse "test_name (size)" format
|
||||
DIR_PATTERN = re.compile(r"^(.*?) \((.*?)\)$")
|
||||
|
||||
# Standard location for criterion estimates relative to the benchmark dir
|
||||
ESTIMATES_PATH_NEW = Path("new") / "estimates.json"
|
||||
# Fallback location (older versions or baseline comparisons)
|
||||
ESTIMATES_PATH_BASE = Path("base") / "estimates.json"
|
||||
|
||||
# Standard location for the HTML report relative to the benchmark's specific directory
|
||||
REPORT_HTML_RELATIVE_PATH = Path("report") / "index.html"
|
||||
|
||||
|
||||
def get_default_criterion_report_path() -> Path:
    """Return the default path of Criterion's top-level HTML report.

    Fix: the previous docstring claimed this returns ``target/criterion``,
    which did not match the value actually built below.

    Returns:
        Path: ``target/criterion/report/index.html`` relative to the crate root.
    """
    return Path("target") / "criterion" / "report" / "index.html"
|
||||
|
||||
|
||||
def load_criterion_reports(
    criterion_root_dir: Path,
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Loads Criterion benchmark results from a specified directory and finds HTML paths.

    Args:
        criterion_root_dir: The Path object pointing to the main
            'target/criterion' directory.

    Returns:
        A nested dictionary structured as:
        { test_name: { size: {'json': json_content, 'html_path': relative_html_path}, ... }, ... }
        Returns an empty dict if the root directory is not found or empty.
    """
    results: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)

    if not criterion_root_dir.is_dir():
        print(
            f"Error: Criterion root directory not found or is not a directory: {criterion_root_dir}",
            file=sys.stderr,
        )
        return {}

    print(f"Scanning for benchmark reports in: {criterion_root_dir}")

    for item in criterion_root_dir.iterdir():
        if not item.is_dir():
            continue

        # Only directories named like "test_name (size)" are benchmark dirs;
        # everything else (e.g. the aggregate "report" dir) is skipped.
        match = DIR_PATTERN.match(item.name)
        if not match:
            continue

        test_name = match.group(1).strip()
        size = match.group(2).strip()
        benchmark_dir_name = item.name
        benchmark_dir_path = item

        json_path: Optional[Path] = None

        # Prefer the fresh "new" estimates; fall back to "base" (older
        # Criterion layouts / baseline comparisons).
        if (benchmark_dir_path / ESTIMATES_PATH_NEW).is_file():
            json_path = benchmark_dir_path / ESTIMATES_PATH_NEW
        elif (benchmark_dir_path / ESTIMATES_PATH_BASE).is_file():
            json_path = benchmark_dir_path / ESTIMATES_PATH_BASE

        html_path = benchmark_dir_path / REPORT_HTML_RELATIVE_PATH

        # A benchmark entry requires BOTH the estimates JSON and the HTML
        # report; otherwise it is skipped with a warning.
        if json_path is None or not json_path.is_file():
            print(
                f"Warning: Could not find estimates JSON in {benchmark_dir_path}. Skipping benchmark size '{test_name} ({size})'.",
                file=sys.stderr,
            )
            continue

        if not html_path.is_file():
            print(
                f"Warning: Could not find HTML report at expected location {html_path}. Skipping benchmark size '{test_name} ({size})'.",
                file=sys.stderr,
            )
            continue

        try:
            with json_path.open("r", encoding="utf-8") as f:
                json_data = json.load(f)

            # Store parsed estimates plus a POSIX-style relative link to the
            # per-benchmark report (backslashes normalised for URL use).
            results[test_name][size] = {
                "json": json_data,
                "html_path_relative_to_criterion_root": str(
                    Path(benchmark_dir_name) / REPORT_HTML_RELATIVE_PATH
                ).replace("\\", "/"),
            }
        except json.JSONDecodeError:
            print(f"Error: Failed to decode JSON from {json_path}", file=sys.stderr)
        except IOError as e:
            print(f"Error: Failed to read file {json_path}: {e}", file=sys.stderr)
        except Exception as e:
            print(
                f"Error: An unexpected error occurred loading {json_path}: {e}",
                file=sys.stderr,
            )

    # Convert the defaultdict back to a plain dict for the caller.
    return dict(results)
|
||||
|
||||
|
||||
def format_nanoseconds(ns: float) -> str:
    """Formats nanoseconds into a human-readable string with units.

    Missing values (NaN / pd.NA) render as "-"; otherwise the value is
    scaled to ns, µs, ms or s with two decimal places.
    """
    if pd.isna(ns):
        return "-"
    # Walk up the unit ladder; each step is a factor of 1000.
    for upper_limit, divisor, unit in (
        (1_000, 1, "ns"),
        (1_000_000, 1_000, "µs"),
        (1_000_000_000, 1_000_000, "ms"),
    ):
        if ns < upper_limit:
            return f"{ns / divisor:.2f} {unit}"
    return f"{ns / 1_000_000_000:.2f} s"
|
||||
|
||||
|
||||
def generate_html_table_with_links(
    results: Dict[str, Dict[str, Dict[str, Any]]], html_base_path: str
) -> str:
    """
    Generates a full HTML page with a styled table from benchmark results.

    Args:
        results: Nested mapping produced by ``load_criterion_reports``:
            ``{test_name: {size: {'json': ..., 'html_path_relative_to_criterion_root': ...}}}``.
        html_base_path: Prefix prepended to every per-benchmark report link.

    Returns:
        A complete HTML document as a string. When ``results`` is empty a
        "no results" placeholder page is returned instead of a table.
    """
    # Inline stylesheet embedded into the generated page.
    css_styles = """
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol";
            line-height: 1.6;
            margin: 0;
            padding: 20px;
            background-color: #f4f7f6;
            color: #333;
        }
        .container {
            max-width: 1200px;
            margin: 20px auto;
            padding: 20px;
            background-color: #fff;
            box-shadow: 0 0 15px rgba(0,0,0,0.1);
            border-radius: 8px;
        }
        h1 {
            color: #2c3e50;
            text-align: center;
            margin-bottom: 10px;
        }
        p.subtitle {
            text-align: center;
            margin-bottom: 8px;
            color: #555;
            font-size: 0.95em;
        }
        p.note {
            text-align: center;
            margin-bottom: 25px;
            color: #777;
            font-size: 0.85em;
        }
        .benchmark-table {
            width: 100%;
            border-collapse: collapse;
            margin-top: 25px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.05);
        }
        .benchmark-table th, .benchmark-table td {
            border: 1px solid #dfe6e9; /* Lighter border */
            padding: 12px 15px;
        }
        .benchmark-table th {
            background-color: #3498db; /* Primary blue */
            color: #ffffff;
            font-weight: 600; /* Slightly bolder */
            text-transform: uppercase;
            letter-spacing: 0.05em;
            text-align: center; /* Center align headers */
        }
        .benchmark-table td {
            text-align: right; /* Default for data cells (times) */
        }
        .benchmark-table td:first-child { /* Benchmark Name column */
            font-weight: 500;
            color: #2d3436;
            text-align: left; /* Left align benchmark names */
        }
        .benchmark-table tbody tr:nth-child(even) {
            background-color: #f8f9fa; /* Very light grey for even rows */
        }
        .benchmark-table tbody tr:hover {
            background-color: #e9ecef; /* Slightly darker on hover */
        }
        .benchmark-table a {
            color: #2980b9; /* Link blue */
            text-decoration: none;
            font-weight: 500;
        }
        .benchmark-table a:hover {
            text-decoration: underline;
            color: #1c5a81; /* Darker blue on hover */
        }
        .no-results {
            text-align: center;
            font-size: 1.2em;
            color: #7f8c8d;
            margin-top: 30px;
        }
    </style>
    """

    html_doc_start = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Criterion Benchmark Results</title>
    {css_styles}
</head>
<body>
    <div class="container">
        <h1 id="criterion-benchmark-results">Criterion Benchmark Results</h1>
"""

    html_doc_end = """
    </div>
</body>
</html>"""

    # Empty input: return a complete "no results" page rather than an
    # empty table.
    if not results:
        return f"""{html_doc_start}
        <p class="no-results">No benchmark results found or loaded.</p>
        {html_doc_end}"""

    # Union of all sizes across tests, sorted numerically by the leading
    # dimension.  NOTE(review): the sort key assumes every size label starts
    # with an integer followed by 'x' (e.g. "100x100") — confirm this holds
    # for all benchmark groups, otherwise int() raises ValueError here.
    all_sizes = sorted(
        list(set(size for test_data in results.values() for size in test_data.keys())),
        key=(lambda x: int(x.split("x")[0])),
    )
    all_test_names = sorted(list(results.keys()))

    table_content = """
    <p class="subtitle">Each cell links to the detailed Criterion.rs report for that specific benchmark size.</p>
    <p class="note">Note: Values shown are the midpoint of the mean confidence interval, formatted for readability.</p>
    <p class="note"><a href="report/index.html">[Switch to the standard Criterion.rs report]</a></p>
    <table class="benchmark-table">
        <thead>
            <tr>
                <th>Benchmark Name</th>
"""

    # One column per size.
    for size in all_sizes:
        table_content += f"<th>{html.escape(size)}</th>\n"

    table_content += """
            </tr>
        </thead>
        <tbody>
"""

    # One row per benchmark; each cell shows the formatted mean and links to
    # the detailed per-size report when one was found.
    for test_name in all_test_names:
        table_content += f"<tr>\n"
        table_content += f"    <td>{html.escape(test_name)}</td>\n"

        for size in all_sizes:
            cell_data = results.get(test_name, {}).get(size)
            mean_value = pd.NA
            full_report_url = "#"

            if (
                cell_data
                and "json" in cell_data
                and "html_path_relative_to_criterion_root" in cell_data
            ):
                try:
                    # The displayed value is the midpoint of the mean's
                    # confidence interval, not the point estimate itself.
                    mean_data = cell_data["json"].get("mean")
                    if mean_data and "confidence_interval" in mean_data:
                        ci = mean_data["confidence_interval"]
                        if "lower_bound" in ci and "upper_bound" in ci:
                            lower, upper = ci["lower_bound"], ci["upper_bound"]
                            if isinstance(lower, (int, float)) and isinstance(
                                upper, (int, float)
                            ):
                                mean_value = (lower + upper) / 2.0
                            else:
                                print(
                                    f"Warning: Non-numeric bounds for {test_name} ({size}).",
                                    file=sys.stderr,
                                )
                        else:
                            print(
                                f"Warning: Missing confidence_interval bounds for {test_name} ({size}).",
                                file=sys.stderr,
                            )
                    else:
                        print(
                            f"Warning: Missing 'mean' data for {test_name} ({size}).",
                            file=sys.stderr,
                        )

                    relative_report_path = cell_data[
                        "html_path_relative_to_criterion_root"
                    ]
                    # Join prefix + relative path; backslashes normalised so
                    # the href works as a URL on Windows-built paths too.
                    joined_path = Path(html_base_path) / relative_report_path
                    full_report_url = str(joined_path).replace("\\", "/")

                except Exception as e:
                    print(
                        f"Error processing cell data for {test_name} ({size}): {e}",
                        file=sys.stderr,
                    )

            formatted_mean = format_nanoseconds(mean_value)

            # Only emit an anchor when a real report URL was resolved;
            # otherwise render the bare value (or "-" for missing data).
            if full_report_url and full_report_url != "#":
                table_content += f'    <td><a href="{html.escape(full_report_url)}">{html.escape(formatted_mean)}</a></td>\n'
            else:
                table_content += f"    <td>{html.escape(formatted_mean)}</td>\n"
        table_content += "</tr>\n"

    table_content += """
        </tbody>
    </table>
"""
    return f"{html_doc_start}{table_content}{html_doc_end}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Defaults: read from target/criterion and write the summary alongside
    # the standard Criterion report so relative links need no prefix.
    DEFAULT_CRITERION_PATH = "target/criterion"
    DEFAULT_OUTPUT_FILE = "./target/criterion/index.html"
    DEFAULT_HTML_BASE_PATH = ""

    parser = argparse.ArgumentParser(
        description="Load Criterion benchmark results from JSON files and generate an HTML table with links to reports."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Perform a dry run without writing the HTML file.",
    )
    parser.add_argument(
        "--criterion-dir",
        type=str,
        default=DEFAULT_CRITERION_PATH,
        help=f"Path to the main 'target/criterion' directory (default: {DEFAULT_CRITERION_PATH}) containing benchmark data.",
    )
    parser.add_argument(
        "--html-base-path",
        type=str,
        default=DEFAULT_HTML_BASE_PATH,
        help=(
            f"Prefix for HTML links to individual benchmark reports. "
            f"This is prepended to each report's relative path (e.g., 'benchmark_name/report/index.html'). "
            f"If the main output HTML (default: '{DEFAULT_OUTPUT_FILE}') is in the 'target/criterion/' directory, "
            f"this should typically be empty (default: '{DEFAULT_HTML_BASE_PATH}'). "
        ),
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=DEFAULT_OUTPUT_FILE,
        help=f"Path to save the generated HTML summary report (default: {DEFAULT_OUTPUT_FILE}).",
    )

    args = parser.parse_args()

    # NOTE(review): --dry-run exits before scanning, so it validates nothing
    # about the criterion directory — confirm this is intended.
    if args.dry_run:
        print(
            "Dry run mode: No files will be written. Use --dry-run to skip writing the HTML file."
        )
        sys.exit(0)

    criterion_path = Path(args.criterion_dir)
    output_file_path = Path(args.output_file)

    # Ensure the output directory exists before doing any heavy work.
    try:
        output_file_path.parent.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(
            f"Error: Could not create output directory {output_file_path.parent}: {e}",
            file=sys.stderr,
        )
        sys.exit(1)

    all_results = load_criterion_reports(criterion_path)

    # Generate HTML output regardless of whether results were found (handles "no results" page)
    html_output = generate_html_table_with_links(all_results, args.html_base_path)

    if not all_results:
        print("\nNo benchmark results found or loaded.")
        # Fallthrough to write the "no results" page generated by generate_html_table_with_links
    else:
        print("\nSuccessfully loaded benchmark results.")
        # pprint(all_results) # Uncomment for debugging

    print(
        f"Generating HTML report with links using HTML base path: '{args.html_base_path}'"
    )

    try:
        with output_file_path.open("w", encoding="utf-8") as f:
            f.write(html_output)
        print(f"\nSuccessfully wrote HTML report to {output_file_path}")
        # Exit non-zero when nothing was found so CI can flag the run, even
        # though the placeholder page has already been written.
        if not all_results:
            sys.exit(1)  # Exit with error code if no results, though file is created
        sys.exit(0)
    except IOError as e:
        print(f"Error writing HTML output to {output_file_path}: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while writing HTML: {e}", file=sys.stderr)
        sys.exit(1)
|
||||
16
.github/scripts/run_examples.sh
vendored
Normal file
16
.github/scripts/run_examples.sh
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
cargo build --release --examples
|
||||
|
||||
for ex in examples/*.rs; do
|
||||
name=$(basename "$ex" .rs)
|
||||
echo
|
||||
echo "🟡 Running example: $name"
|
||||
|
||||
if ! cargo run --release --example "$name" -- --debug; then
|
||||
echo
|
||||
echo "❌ Example '$name' failed. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo
|
||||
echo "✅ All examples ran successfully."
|
||||
126
.github/workflows/docs-and-testcov.yml
vendored
126
.github/workflows/docs-and-testcov.yml
vendored
@@ -7,9 +7,13 @@ concurrency:
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
# pull_request:
|
||||
# branches: [main]
|
||||
# pull_request:
|
||||
# branches: [main]
|
||||
workflow_dispatch:
|
||||
workflow_run:
|
||||
workflows: ["run-benchmarks"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -17,8 +21,23 @@ permissions:
|
||||
pages: write
|
||||
|
||||
jobs:
|
||||
docs-and-testcov:
|
||||
pick-runner:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
runner: ${{ steps.choose.outputs.use-runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- id: choose
|
||||
uses: ./.github/actions/runner-fallback
|
||||
with:
|
||||
primary-runner: "self-hosted"
|
||||
fallback-runner: "ubuntu-latest"
|
||||
github-token: ${{ secrets.CUSTOM_GH_TOKEN }}
|
||||
|
||||
docs-and-testcov:
|
||||
needs: pick-runner
|
||||
runs-on: ${{ fromJson(needs.pick-runner.outputs.runner) }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -29,6 +48,14 @@ jobs:
|
||||
toolchain: stable
|
||||
override: true
|
||||
|
||||
- name: Replace logo URL in README.md
|
||||
env:
|
||||
LOGO_URL: ${{ secrets.LOGO_URL }}
|
||||
run: |
|
||||
# replace with EXAMPLE.COM/LOGO
|
||||
|
||||
sed -i 's|.github/rustframe_logo.png|rustframe_logo.png|g' README.md
|
||||
|
||||
- name: Build documentation
|
||||
run: cargo doc --no-deps --release
|
||||
|
||||
@@ -54,12 +81,10 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Export tarpaulin coverage badge JSON
|
||||
# extract raw coverage and round to 2 decimal places
|
||||
run: |
|
||||
# extract raw coverage
|
||||
coverage=$(jq '.coverage' tarpaulin-report.json)
|
||||
# round to 2 decimal places
|
||||
formatted=$(printf "%.2f" "$coverage")
|
||||
# build the badge JSON using the pre-formatted string
|
||||
jq --arg message "$formatted" \
|
||||
'{schemaVersion:1,
|
||||
label:"tarpaulin-report",
|
||||
@@ -79,23 +104,92 @@ jobs:
|
||||
<(echo '{}') \
|
||||
> last-commit-date.json
|
||||
|
||||
- name: Download last available benchmark report
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.CUSTOM_GH_TOKEN }}
|
||||
run: |
|
||||
artifact_url=$(
|
||||
curl -sSL \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${GH_TOKEN}" \
|
||||
"https://api.github.com/repos/${{ github.repository }}/actions/artifacts" \
|
||||
| jq -r '
|
||||
.artifacts[]
|
||||
| select(.name | startswith("benchmark-reports"))
|
||||
| .archive_download_url
|
||||
' \
|
||||
| head -n 1
|
||||
)
|
||||
|
||||
if [ -z "$artifact_url" ]; then
|
||||
echo "No benchmark artifact found!"
|
||||
mkdir -p benchmark-report
|
||||
echo '<!DOCTYPE html><html><head><title>No Benchmarks</title></head><body><h1>No benchmarks available</h1></body></html>' > benchmark-report/index.html
|
||||
exit 0
|
||||
fi
|
||||
|
||||
curl -L -H "Authorization: Bearer ${GH_TOKEN}" \
|
||||
"$artifact_url" -o benchmark-report.zip
|
||||
|
||||
# Print all files in the current directory
|
||||
echo "Files in the current directory:"
|
||||
ls -al
|
||||
|
||||
# check if the zip file is valid
|
||||
if ! unzip -tq benchmark-report.zip; then
|
||||
echo "benchmark-report.zip is invalid or corrupted!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
unzip -q benchmark-report.zip -d benchmark-report
|
||||
|
||||
# echo "<meta http-equiv=\"refresh\" content=\"0; url=report/index.html\">" > benchmark-report/index.html
|
||||
|
||||
- name: Copy files to output directory
|
||||
run: |
|
||||
# mkdir docs
|
||||
mkdir -p target/doc/docs
|
||||
mv target/doc/rustframe/* target/doc/docs/
|
||||
|
||||
echo "<meta http-equiv=\"refresh\" content=\"0; url=../docs/index.html\">" > target/doc/rustframe/index.html
|
||||
|
||||
cp tarpaulin-report.html target/doc/docs/
|
||||
cp tarpaulin-report.json target/doc/docs/
|
||||
cp tarpaulin-badge.json target/doc/docs/
|
||||
cp last-commit-date.json target/doc/docs/
|
||||
# cp -r .github target/doc/docs
|
||||
cp .github/rustframe_logo.png target/doc/docs/
|
||||
# echo "<meta http-equiv=\"refresh\" content=\"0; url=docs\">" > target/doc/index.html
|
||||
touch target/doc/.nojekyll
|
||||
|
||||
# copy the benchmark report to the output directory
|
||||
cp -r benchmark-report target/doc/
|
||||
|
||||
mkdir output
|
||||
cp tarpaulin-report.html target/doc/rustframe/
|
||||
cp tarpaulin-report.json target/doc/rustframe/
|
||||
cp tarpaulin-badge.json target/doc/rustframe/
|
||||
cp last-commit-date.json target/doc/rustframe/
|
||||
mkdir -p target/doc/rustframe/.github
|
||||
cp .github/rustframe_logo.png target/doc/rustframe/.github/
|
||||
echo "<meta http-equiv=\"refresh\" content=\"0; url=rustframe\">" > target/doc/index.html
|
||||
cp -r target/doc/* output/
|
||||
|
||||
- name: Build user guide
|
||||
run: |
|
||||
cargo binstall mdbook
|
||||
bash ./docs/build.sh
|
||||
|
||||
- name: Copy user guide to output directory
|
||||
run: |
|
||||
mkdir output/user-guide
|
||||
cp -r docs/book/* output/user-guide/
|
||||
|
||||
- name: Add index.html to output directory
|
||||
run: |
|
||||
cp .github/htmldocs/index.html output/index.html
|
||||
cp .github/rustframe_logo.png output/rustframe_logo.png
|
||||
|
||||
- name: Upload Pages artifact
|
||||
if: github.event_name == 'push'
|
||||
# if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
|
||||
uses: actions/upload-pages-artifact@v3
|
||||
with:
|
||||
path: target/doc/
|
||||
# path: target/doc/
|
||||
path: output/
|
||||
|
||||
- name: Deploy to GitHub Pages
|
||||
if: github.event_name == 'push'
|
||||
# if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
|
||||
uses: actions/deploy-pages@v4
|
||||
|
||||
63
.github/workflows/run-benchmarks.yml
vendored
Normal file
63
.github/workflows/run-benchmarks.yml
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
name: run-benchmarks
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
pick-runner:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
runner: ${{ steps.choose.outputs.use-runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- id: choose
|
||||
uses: ./.github/actions/runner-fallback
|
||||
with:
|
||||
primary-runner: "self-hosted"
|
||||
fallback-runner: "ubuntu-latest"
|
||||
github-token: ${{ secrets.CUSTOM_GH_TOKEN }}
|
||||
|
||||
run-benchmarks:
|
||||
needs: pick-runner
|
||||
runs-on: ${{ fromJson(needs.pick-runner.outputs.runner) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
- name: Setup venv
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install pandas
|
||||
uv run .github/scripts/custom_benchmark_report.py --dry-run
|
||||
|
||||
- name: Run benchmarks
|
||||
run: cargo bench --features bench
|
||||
|
||||
- name: Generate custom benchmark reports
|
||||
run: |
|
||||
if [ -d ./target/criterion ]; then
|
||||
echo "Found benchmark reports, generating custom report..."
|
||||
else
|
||||
echo "No benchmark reports found, skipping custom report generation."
|
||||
exit 1
|
||||
fi
|
||||
uv run .github/scripts/custom_benchmark_report.py
|
||||
|
||||
- name: Upload benchmark reports
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-reports-${{ github.sha }}
|
||||
path: ./target/criterion/
|
||||
56
.github/workflows/run-unit-tests.yml
vendored
56
.github/workflows/run-unit-tests.yml
vendored
@@ -11,25 +11,62 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
pick-runner:
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
runner: ${{ steps.choose.outputs.use-runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- id: choose
|
||||
uses: ./.github/actions/runner-fallback
|
||||
with:
|
||||
primary-runner: "self-hosted"
|
||||
fallback-runner: "ubuntu-latest"
|
||||
github-token: ${{ secrets.CUSTOM_GH_TOKEN }}
|
||||
|
||||
run-unit-tests:
|
||||
needs: pick-runner
|
||||
if: github.event.pull_request.draft == false
|
||||
name: run-unit-tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ fromJson(needs.pick-runner.outputs.runner) }}
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install Rust
|
||||
run: rustup update stable
|
||||
- name: Set up Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
override: true
|
||||
- name: Install cargo-llvm-cov
|
||||
uses: taiki-e/install-action@cargo-llvm-cov
|
||||
- name: Generate code coverage
|
||||
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
|
||||
- name: Run doc-tests
|
||||
run: cargo test --doc --all-features --workspace --release
|
||||
|
||||
- name: Run doctests
|
||||
run: cargo test --doc --release
|
||||
|
||||
- name: Run unit tests with code coverage
|
||||
run: cargo llvm-cov --release --lcov --output-path lcov.info
|
||||
|
||||
- name: Test docs generation
|
||||
run: cargo doc --no-deps --release
|
||||
|
||||
- name: Test examples
|
||||
run: cargo test --examples --release
|
||||
|
||||
- name: Run all examples
|
||||
run: |
|
||||
for example in examples/*.rs; do
|
||||
name=$(basename "$example" .rs)
|
||||
echo "Running example: $name"
|
||||
cargo run --release --example "$name" -- --debug || exit 1
|
||||
done
|
||||
|
||||
- name: Cargo test all targets
|
||||
run: cargo test --all-targets --release
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
@@ -41,3 +78,8 @@ jobs:
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Test build user guide
|
||||
run: |
|
||||
cargo binstall mdbook
|
||||
bash ./docs/build.sh
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -15,3 +15,7 @@ data/
|
||||
.vscode/
|
||||
|
||||
tarpaulin-report.*
|
||||
|
||||
.github/htmldocs/rustframe_logo.png
|
||||
|
||||
docs/book/
|
||||
795
Cargo.lock
generated
795
Cargo.lock
generated
@@ -1,795 +0,0 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anes"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
|
||||
|
||||
[[package]]
|
||||
name = "cast"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e3a13707ac958681c13b39b458c073d0d9bc8a22cb1b2f4c8e55eb72c13f362"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"ciborium-ll",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-io"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-ll"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "3.2.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"clap_lex",
|
||||
"indexmap",
|
||||
"textwrap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
|
||||
dependencies = [
|
||||
"os_str_bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"atty",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
"num-traits",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.63"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.77"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.172"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
version = "11.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
|
||||
|
||||
[[package]]
|
||||
name = "os_str_bytes"
|
||||
version = "6.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"plotters-backend",
|
||||
"plotters-svg",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "plotters-backend"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
|
||||
|
||||
[[package]]
|
||||
name = "plotters-svg"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
|
||||
dependencies = [
|
||||
"plotters-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.95"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||
|
||||
[[package]]
|
||||
name = "rustframe"
|
||||
version = "0.0.1-a.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"criterion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.140"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"rustversion",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"log",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.77"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.61.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980"
|
||||
dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-interface"
|
||||
version = "0.59.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38"
|
||||
|
||||
[[package]]
|
||||
name = "windows-result"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-strings"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.59.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -1,10 +1,12 @@
|
||||
[package]
|
||||
name = "rustframe"
|
||||
version = "0.0.1-a.0"
|
||||
authors = ["Palash Tyagi (https://github.com/Magnus167)"]
|
||||
version = "0.0.1-a.20250805"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
readme = "README.md"
|
||||
description = "A simple dataframe library"
|
||||
description = "A simple dataframe and math toolkit"
|
||||
documentation = "https://magnus167.github.io/rustframe/"
|
||||
|
||||
[lib]
|
||||
name = "rustframe"
|
||||
@@ -13,10 +15,12 @@ crate-type = ["cdylib", "lib"]
|
||||
|
||||
[dependencies]
|
||||
chrono = "^0.4.10"
|
||||
criterion = { version = "0.5", features = ["html_reports"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.4", features = ["html_reports"] }
|
||||
[features]
|
||||
bench = ["dep:criterion"]
|
||||
|
||||
[[bench]]
|
||||
name = "benchmarks"
|
||||
harness = false
|
||||
required-features = ["bench"]
|
||||
|
||||
161
README.md
161
README.md
@@ -1,38 +1,63 @@
|
||||
# rustframe
|
||||
|
||||
# <img align="center" alt="Rustframe" src=".github/rustframe_logo.png" height="50" /> rustframe
|
||||
|
||||
<!-- though the centre tag doesn't work as it would noramlly, it achieves the desired effect -->
|
||||
|
||||
📚 [Docs](https://magnus167.github.io/rustframe/) | 🐙 [GitHub](https://github.com/Magnus167/rustframe) | 🌐 [Gitea mirror](https://gitea.nulltech.uk/Magnus167/rustframe) | 🦀 [Crates.io](https://crates.io/crates/rustframe) | 🔖 [docs.rs](https://docs.rs/rustframe/latest/rustframe/)
|
||||
🐙 [GitHub](https://github.com/Magnus167/rustframe) | 📚 [Docs](https://magnus167.github.io/rustframe/) | 📖 [User Guide](https://magnus167.github.io/rustframe/user-guide/) | 🦀 [Crates.io](https://crates.io/crates/rustframe) | 🔖 [docs.rs](https://docs.rs/rustframe/latest/rustframe/)
|
||||
|
||||
<!-- [](https://github.com/Magnus167/rustframe) -->
|
||||
|
||||
[](https://codecov.io/gh/Magnus167/rustframe)
|
||||
[](https://magnus167.github.io/rustframe/rustframe/tarpaulin-report.html)
|
||||
[](https://magnus167.github.io/rustframe/docs/tarpaulin-report.html)
|
||||
[](https://gitea.nulltech.uk/Magnus167/rustframe)
|
||||
|
||||
---
|
||||
|
||||
## Rustframe: *A lightweight dataframe & math toolkit for Rust*
|
||||
## Rustframe: _A lightweight dataframe & math toolkit for Rust_
|
||||
|
||||
Rustframe provides intuitive dataframe, matrix, and series operations small-to-mid scale data analysis and manipulation.
|
||||
Rustframe provides intuitive dataframe, matrix, and series operations for data analysis and manipulation.
|
||||
|
||||
Rustframe keeps things simple, safe, and readable. It is handy for quick numeric experiments and small analytical tasks, but it is **not** meant to compete with powerhouse crates like `polars` or `ndarray`.
|
||||
Rustframe keeps things simple, safe, and readable. It is handy for quick numeric experiments and small analytical tasks as well as for educational purposes. It is designed to be easy to use and understand, with a clean API implemented in 100% safe Rust.
|
||||
|
||||
Rustframe is an educational project, and is not intended for production use. It is **not** meant to compete with powerhouse crates like `polars` or `ndarray`. It is a work in progress, and the API is subject to change. There are no guarantees of stability or performance, and it is not optimized for large datasets or high-performance computing.
|
||||
|
||||
### What it offers
|
||||
|
||||
- **Math that reads like math** - element‑wise `+`, `−`, `×`, `÷` on entire frames or scalars.
|
||||
- **Broadcast & reduce** - sum, product, any/all across rows or columns without boilerplate.
|
||||
- **Boolean masks made simple** - chain comparisons, combine with `&`/`|`, get a tidy `BoolMatrix` back.
|
||||
- **Date‑centric row index** - business‑day ranges and calendar slicing built in.
|
||||
- **Pure safe Rust** - 100 % safe, zero `unsafe`.
|
||||
- **Matrix operations** - Element-wise arithmetic, boolean logic, transpose, and more.
|
||||
- **Math that reads like math** - element-wise `+`, `−`, `×`, `÷` on entire frames or scalars.
|
||||
- **Frames** - Column major data structure for single-type data, with labeled columns and typed row indices.
|
||||
- **Compute module** - Implements various statistical computations and machine learning models.
|
||||
- **Random number utils** - Built-in pseudo and cryptographically secure generators for simulations.
|
||||
- **[Coming Soon]** _DataFrame_ - Multi-type data structure for heterogeneous data, with labeled columns and typed row indices.
|
||||
|
||||
#### Matrix and Frame functionality
|
||||
|
||||
- **Matrix operations** - Element-wise arithmetic, boolean logic, transpose, and more.
|
||||
- **Frame operations** - Column manipulation, sorting, and more.
|
||||
|
||||
#### Compute Module
|
||||
|
||||
The `compute` module provides implementations for various statistical computations and machine learning models.
|
||||
|
||||
**Statistics, Data Analysis, and Machine Learning:**
|
||||
|
||||
- Correlation analysis
|
||||
- Descriptive statistics
|
||||
- Distributions
|
||||
- Inferential statistics
|
||||
|
||||
- Dense Neural Networks
|
||||
- Gaussian Naive Bayes
|
||||
- K-Means Clustering
|
||||
- Linear Regression
|
||||
- Logistic Regression
|
||||
- Principal Component Analysis
|
||||
|
||||
### Heads up
|
||||
|
||||
- **Not memory‑efficient (yet)** - footprint needs work.
|
||||
- **Feature set still small** - expect missing pieces.
|
||||
- **The feature set is still limited** - expect missing pieces.
|
||||
|
||||
### On the horizon
|
||||
### Somewhere down the line
|
||||
|
||||
- Optional GPU help (Vulkan or similar) for heavier workloads.
|
||||
- Optional GPU acceleration (Vulkan or similar) for heavier workloads.
|
||||
- Straightforward Python bindings using `pyo3`.
|
||||
|
||||
---
|
||||
@@ -44,17 +69,16 @@ use chrono::NaiveDate;
|
||||
use rustframe::{
|
||||
frame::{Frame, RowIndex},
|
||||
matrix::{BoolOps, Matrix, SeriesOps},
|
||||
utils::{BDateFreq, BDatesList},
|
||||
utils::{DateFreq, BDatesList},
|
||||
};
|
||||
|
||||
let n_periods = 4;
|
||||
|
||||
// Four business days starting 2024‑01‑02
|
||||
// Four business days starting 2024-01-02
|
||||
let dates: Vec<NaiveDate> =
|
||||
BDatesList::from_n_periods("2024-01-02".to_string(), BDateFreq::Daily, n_periods)
|
||||
BDatesList::from_n_periods("2024-01-02".to_string(), DateFreq::Daily, n_periods)
|
||||
.unwrap()
|
||||
.list()
|
||||
.unwrap();
|
||||
.list().unwrap();
|
||||
|
||||
let col_names: Vec<String> = vec!["a".to_string(), "b".to_string()];
|
||||
|
||||
@@ -85,17 +109,104 @@ let result: Matrix<f64> = result / 2.0; // divide by scalar
|
||||
let check: bool = result.eq_elem(ma.clone()).all();
|
||||
assert!(check);
|
||||
|
||||
// The above math can also be written as:
|
||||
// Alternatively:
|
||||
let check: bool = (&(&(&(&ma + 1.0) - 1.0) * 2.0) / 2.0)
|
||||
.eq_elem(ma.clone())
|
||||
.all();
|
||||
assert!(check);
|
||||
|
||||
// The above math can also be written as:
|
||||
// or even as:
|
||||
let check: bool = ((((ma.clone() + 1.0) - 1.0) * 2.0) / 2.0)
|
||||
.eq_elem(ma)
|
||||
.eq_elem(ma.clone())
|
||||
.all();
|
||||
assert!(check);
|
||||
|
||||
// Matrix multiplication
|
||||
let mc: Matrix<f64> = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
|
||||
let md: Matrix<f64> = Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
|
||||
let mul_result: Matrix<f64> = mc.matrix_mul(&md);
|
||||
// Expected:
|
||||
assert_eq!(mul_result.data(), &[23.0, 34.0, 31.0, 46.0]);
|
||||
|
||||
// Dot product (alias for matrix_mul for FloatMatrix)
|
||||
let dot_result: Matrix<f64> = mc.dot(&md);
|
||||
assert_eq!(dot_result, mul_result);
|
||||
|
||||
// Transpose
|
||||
let original_matrix: Matrix<f64> = Matrix::from_cols(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
|
||||
let transposed_matrix: Matrix<f64> = original_matrix.transpose();
|
||||
assert_eq!(transposed_matrix.rows(), 2);
|
||||
assert_eq!(transposed_matrix.cols(), 3);
|
||||
assert_eq!(transposed_matrix.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
|
||||
|
||||
// Map
|
||||
let matrix = Matrix::from_cols(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
|
||||
// Map function to double each value
|
||||
let mapped_matrix = matrix.map(|x| x * 2.0);
|
||||
assert_eq!(mapped_matrix.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
|
||||
|
||||
// Zip
|
||||
let a = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); // 2x2 matrix
|
||||
let b = Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]); // 2x2 matrix
|
||||
// Zip function to add corresponding elements
|
||||
let zipped_matrix = a.zip(&b, |x, y| x + y);
|
||||
assert_eq!(zipped_matrix.data(), &[6.0, 8.0, 10.0, 12.0]);
|
||||
```
|
||||
|
||||
## More examples
|
||||
|
||||
See the [examples](./examples/) directory for some demonstrations of Rustframe's syntax and functionality.
|
||||
|
||||
To run the examples, use:
|
||||
|
||||
```bash
|
||||
cargo run --example <example_name>
|
||||
```
|
||||
|
||||
E.g. to run the `game_of_life` example:
|
||||
|
||||
```bash
|
||||
cargo run --example game_of_life
|
||||
```
|
||||
|
||||
More demos:
|
||||
|
||||
```bash
|
||||
cargo run --example linear_regression
|
||||
cargo run --example logistic_regression
|
||||
cargo run --example k_means
|
||||
cargo run --example pca
|
||||
cargo run --example stats_overview
|
||||
cargo run --example descriptive_stats
|
||||
cargo run --example correlation
|
||||
cargo run --example inferential_stats
|
||||
cargo run --example distributions
|
||||
```
|
||||
|
||||
To simply list all available examples, you can run:
|
||||
|
||||
```bash
|
||||
# this technically raises an error, but it will list all examples
|
||||
cargo run --example
|
||||
```
|
||||
|
||||
Each demo runs a couple of mini-scenarios showcasing the APIs.
|
||||
|
||||
## Running benchmarks
|
||||
|
||||
To run the benchmarks, use:
|
||||
|
||||
```bash
|
||||
cargo bench --features "bench"
|
||||
```
|
||||
|
||||
## Building the user-guide
|
||||
|
||||
To build the user guide, use:
|
||||
|
||||
```bash
|
||||
cargo binstall mdbook
|
||||
bash docs/build.sh
|
||||
```
|
||||
|
||||
This will generate the user guide in the `docs/book` directory.
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
// Combined benchmarks for rustframe
|
||||
// Combined benchmarks
|
||||
use chrono::NaiveDate;
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
|
||||
use rustframe::{
|
||||
frame::{Frame, RowIndex},
|
||||
matrix::{BoolMatrix, Matrix},
|
||||
utils::{BDateFreq, BDatesList},
|
||||
matrix::{Axis, BoolMatrix, Matrix, SeriesOps},
|
||||
utils::{DateFreq, DatesList},
|
||||
};
|
||||
use chrono::NaiveDate;
|
||||
use std::time::Duration;
|
||||
|
||||
fn bool_matrix_operations_benchmark(c: &mut Criterion) {
|
||||
let sizes = [1, 100, 1000];
|
||||
// Define size categories
|
||||
const SIZES_SMALL: [usize; 1] = [1];
|
||||
const SIZES_MEDIUM: [usize; 3] = [100, 250, 500];
|
||||
const SIZES_LARGE: [usize; 1] = [1000];
|
||||
|
||||
for &size in &sizes {
|
||||
// Modified benchmark functions to accept a slice of sizes
|
||||
fn bool_matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
|
||||
for &size in sizes {
|
||||
let data1: Vec<bool> = (0..size * size).map(|x| x % 2 == 0).collect();
|
||||
let data2: Vec<bool> = (0..size * size).map(|x| x % 3 == 0).collect();
|
||||
let bm1 = BoolMatrix::from_vec(data1.clone(), size, size);
|
||||
@@ -42,10 +48,8 @@ fn bool_matrix_operations_benchmark(c: &mut Criterion) {
|
||||
}
|
||||
}
|
||||
|
||||
fn matrix_boolean_operations_benchmark(c: &mut Criterion) {
|
||||
let sizes = [1, 100, 1000];
|
||||
|
||||
for &size in &sizes {
|
||||
fn matrix_boolean_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
|
||||
for &size in sizes {
|
||||
let data1: Vec<bool> = (0..size * size).map(|x| x % 2 == 0).collect();
|
||||
let data2: Vec<bool> = (0..size * size).map(|x| x % 3 == 0).collect();
|
||||
let bm1 = BoolMatrix::from_vec(data1.clone(), size, size);
|
||||
@@ -77,35 +81,8 @@ fn matrix_boolean_operations_benchmark(c: &mut Criterion) {
|
||||
}
|
||||
}
|
||||
|
||||
fn matrix_operations_benchmark(c: &mut Criterion) {
|
||||
let n_periods = 4;
|
||||
let dates: Vec<NaiveDate> =
|
||||
BDatesList::from_n_periods("2024-01-02".to_string(), BDateFreq::Daily, n_periods)
|
||||
.unwrap()
|
||||
.list()
|
||||
.unwrap();
|
||||
|
||||
let col_names: Vec<String> = vec!["a".to_string(), "b".to_string()];
|
||||
|
||||
let ma = Matrix::from_cols(vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]);
|
||||
let mb = Matrix::from_cols(vec![vec![4.0, 3.0, 2.0, 1.0], vec![8.0, 7.0, 6.0, 5.0]]);
|
||||
|
||||
let fa = Frame::new(
|
||||
ma.clone(),
|
||||
col_names.clone(),
|
||||
Some(RowIndex::Date(dates.clone())),
|
||||
);
|
||||
let fb = Frame::new(mb, col_names, Some(RowIndex::Date(dates)));
|
||||
|
||||
c.bench_function("element-wise multiply", |b| {
|
||||
b.iter(|| {
|
||||
let _result = &fa * &fb;
|
||||
});
|
||||
});
|
||||
|
||||
let sizes = [1, 100, 1000];
|
||||
|
||||
for &size in &sizes {
|
||||
fn matrix_operations_benchmark(c: &mut Criterion, sizes: &[usize]) {
|
||||
for &size in sizes {
|
||||
let data: Vec<f64> = (0..size * size).map(|x| x as f64).collect();
|
||||
let ma = Matrix::from_vec(data.clone(), size, size);
|
||||
|
||||
@@ -132,8 +109,226 @@ fn matrix_operations_benchmark(c: &mut Criterion) {
|
||||
let _result = &ma / 2.0;
|
||||
});
|
||||
});
|
||||
c.bench_function(
|
||||
&format!("matrix matrix_multiply ({}x{})", size, size),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.matrix_mul(&ma);
|
||||
});
|
||||
},
|
||||
);
|
||||
c.bench_function(&format!("matrix sum_horizontal ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.sum_horizontal();
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("matrix sum_vertical ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.sum_vertical();
|
||||
});
|
||||
});
|
||||
c.bench_function(
|
||||
&format!("matrix prod_horizontal ({}x{})", size, size),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.prod_horizontal();
|
||||
});
|
||||
},
|
||||
);
|
||||
c.bench_function(&format!("matrix prod_vertical ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.prod_vertical();
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("matrix apply_axis ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("matrix transpose ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = ma.transpose();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
for &size in sizes {
|
||||
let data1: Vec<f64> = (0..size * size).map(|x| x as f64).collect();
|
||||
let data2: Vec<f64> = (0..size * size).map(|x| (x + 1) as f64).collect();
|
||||
let ma = Matrix::from_vec(data1.clone(), size, size);
|
||||
let mb = Matrix::from_vec(data2.clone(), size, size);
|
||||
|
||||
c.bench_function(&format!("matrix add ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &ma + &mb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("matrix subtract ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &ma - &mb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("matrix multiply ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &ma * &mb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("matrix divide ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &ma / &mb;
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(combined_benches, bool_matrix_operations_benchmark, matrix_boolean_operations_benchmark, matrix_operations_benchmark);
|
||||
criterion_main!(combined_benches);
|
||||
fn generate_frame(size: usize) -> Frame<f64> {
|
||||
let data: Vec<f64> = (0..size * size).map(|x| x as f64).collect();
|
||||
let dates: Vec<NaiveDate> =
|
||||
DatesList::from_n_periods("2000-01-01".to_string(), DateFreq::Daily, size)
|
||||
.unwrap()
|
||||
.list()
|
||||
.unwrap();
|
||||
let col_names: Vec<String> = (1..=size).map(|i| format!("col_{}", i)).collect();
|
||||
Frame::new(
|
||||
Matrix::from_vec(data.clone(), size, size),
|
||||
col_names,
|
||||
Some(RowIndex::Date(dates)),
|
||||
)
|
||||
}
|
||||
|
||||
fn benchmark_frame_operations(c: &mut Criterion, sizes: &[usize]) {
|
||||
for &size in sizes {
|
||||
let fa = generate_frame(size);
|
||||
let fb = generate_frame(size);
|
||||
|
||||
c.bench_function(&format!("frame add ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &fa + &fb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("frame subtract ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &fa - &fb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("frame multiply ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &fa * &fb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("frame divide ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = &fa / &fb;
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("frame matrix_multiply ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.matrix_mul(&fb);
|
||||
});
|
||||
});
|
||||
|
||||
c.bench_function(&format!("frame sum_horizontal ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.sum_horizontal();
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("frame sum_vertical ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.sum_vertical();
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("frame prod_horizontal ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.prod_horizontal();
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("frame prod_vertical ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.prod_vertical();
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("frame apply_axis ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.apply_axis(Axis::Col, |col| col.iter().sum::<f64>());
|
||||
});
|
||||
});
|
||||
c.bench_function(&format!("frame transpose ({}x{})", size, size), |b| {
|
||||
b.iter(|| {
|
||||
let _result = fa.transpose();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Runner functions for each size category
|
||||
fn run_benchmarks_small(c: &mut Criterion) {
|
||||
bool_matrix_operations_benchmark(c, &SIZES_SMALL);
|
||||
matrix_boolean_operations_benchmark(c, &SIZES_SMALL);
|
||||
matrix_operations_benchmark(c, &SIZES_SMALL);
|
||||
benchmark_frame_operations(c, &SIZES_SMALL);
|
||||
}
|
||||
|
||||
fn run_benchmarks_medium(c: &mut Criterion) {
|
||||
bool_matrix_operations_benchmark(c, &SIZES_MEDIUM);
|
||||
matrix_boolean_operations_benchmark(c, &SIZES_MEDIUM);
|
||||
matrix_operations_benchmark(c, &SIZES_MEDIUM);
|
||||
benchmark_frame_operations(c, &SIZES_MEDIUM);
|
||||
}
|
||||
|
||||
fn run_benchmarks_large(c: &mut Criterion) {
|
||||
bool_matrix_operations_benchmark(c, &SIZES_LARGE);
|
||||
matrix_boolean_operations_benchmark(c, &SIZES_LARGE);
|
||||
matrix_operations_benchmark(c, &SIZES_LARGE);
|
||||
benchmark_frame_operations(c, &SIZES_LARGE);
|
||||
}
|
||||
|
||||
// Configuration functions for different size categories
|
||||
fn config_small_arrays() -> Criterion {
|
||||
Criterion::default()
|
||||
.sample_size(500)
|
||||
.measurement_time(Duration::from_millis(100))
|
||||
.warm_up_time(Duration::from_millis(5))
|
||||
}
|
||||
|
||||
fn config_medium_arrays() -> Criterion {
|
||||
Criterion::default()
|
||||
.sample_size(100)
|
||||
.measurement_time(Duration::from_millis(2000))
|
||||
.warm_up_time(Duration::from_millis(100))
|
||||
}
|
||||
|
||||
fn config_large_arrays() -> Criterion {
|
||||
Criterion::default()
|
||||
.sample_size(50)
|
||||
.measurement_time(Duration::from_millis(5000))
|
||||
.warm_up_time(Duration::from_millis(200))
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches_small_arrays;
|
||||
config = config_small_arrays();
|
||||
targets = run_benchmarks_small
|
||||
);
|
||||
criterion_group!(
|
||||
name = benches_medium_arrays;
|
||||
config = config_medium_arrays();
|
||||
targets = run_benchmarks_medium
|
||||
);
|
||||
criterion_group!(
|
||||
name = benches_large_arrays;
|
||||
config = config_large_arrays();
|
||||
targets = run_benchmarks_large
|
||||
);
|
||||
|
||||
criterion_main!(
|
||||
benches_small_arrays,
|
||||
benches_medium_arrays,
|
||||
benches_large_arrays
|
||||
);
|
||||
|
||||
7
docs/book.toml
Normal file
7
docs/book.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[book]
|
||||
title = "Rustframe User Guide"
|
||||
authors = ["Palash Tyagi (https://github.com/Magnus167)"]
|
||||
description = "Guided journey through Rustframe capabilities."
|
||||
|
||||
[build]
|
||||
build-dir = "book"
|
||||
7
docs/build.sh
Executable file
7
docs/build.sh
Executable file
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env sh
|
||||
# Build and test the Rustframe user guide using mdBook.
|
||||
set -e
|
||||
|
||||
cd docs
|
||||
bash gen.sh "$@"
|
||||
cd ..
|
||||
14
docs/gen.sh
Normal file
14
docs/gen.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
set -e
|
||||
|
||||
cargo clean
|
||||
|
||||
cargo build --manifest-path ../Cargo.toml
|
||||
|
||||
mdbook test -L ../target/debug/deps "$@"
|
||||
|
||||
mdbook build "$@"
|
||||
|
||||
cargo build
|
||||
# cargo build --release
|
||||
7
docs/src/SUMMARY.md
Normal file
7
docs/src/SUMMARY.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Summary
|
||||
|
||||
- [Introduction](./introduction.md)
|
||||
- [Data Manipulation](./data-manipulation.md)
|
||||
- [Compute Features](./compute.md)
|
||||
- [Machine Learning](./machine-learning.md)
|
||||
- [Utilities](./utilities.md)
|
||||
222
docs/src/compute.md
Normal file
222
docs/src/compute.md
Normal file
@@ -0,0 +1,222 @@
|
||||
# Compute Features
|
||||
|
||||
The `compute` module hosts numerical routines for exploratory data analysis.
|
||||
It covers descriptive statistics, correlations, probability distributions and
|
||||
some basic inferential tests.
|
||||
|
||||
## Basic Statistics
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::{mean, mean_horizontal, mean_vertical, stddev, median, population_variance, percentile};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let m = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
assert_eq!(mean(&m), 2.5);
|
||||
assert_eq!(stddev(&m), 1.118033988749895);
|
||||
assert_eq!(median(&m), 2.5);
|
||||
assert_eq!(population_variance(&m), 1.25);
|
||||
assert_eq!(percentile(&m, 50.0), 3.0);
|
||||
// column averages returned as 1 x n matrix
|
||||
let row_means = mean_horizontal(&m);
|
||||
assert_eq!(row_means.data(), &[2.0, 3.0]);
|
||||
let col_means = mean_vertical(&m);
|
||||
assert_eq!(col_means.data(), & [1.5, 3.5]);
|
||||
```
|
||||
|
||||
### Axis-specific Operations
|
||||
|
||||
Operations can be applied along specific axes (rows or columns):
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::{mean_vertical, mean_horizontal, stddev_vertical, stddev_horizontal};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// 3x2 matrix
|
||||
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
|
||||
|
||||
// Mean along columns (vertical) - returns 1 x cols matrix
|
||||
let col_means = mean_vertical(&m);
|
||||
assert_eq!(col_means.shape(), (1, 2));
|
||||
assert_eq!(col_means.data(), &[3.0, 4.0]); // [(1+3+5)/3, (2+4+6)/3]
|
||||
|
||||
// Mean along rows (horizontal) - returns rows x 1 matrix
|
||||
let row_means = mean_horizontal(&m);
|
||||
assert_eq!(row_means.shape(), (3, 1));
|
||||
assert_eq!(row_means.data(), &[1.5, 3.5, 5.5]); // [(1+2)/2, (3+4)/2, (5+6)/2]
|
||||
|
||||
// Standard deviation along columns
|
||||
let col_stddev = stddev_vertical(&m);
|
||||
assert_eq!(col_stddev.shape(), (1, 2));
|
||||
|
||||
// Standard deviation along rows
|
||||
let row_stddev = stddev_horizontal(&m);
|
||||
assert_eq!(row_stddev.shape(), (3, 1));
|
||||
```
|
||||
|
||||
## Correlation
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::{pearson, covariance};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let y = Matrix::from_vec(vec![2.0, 4.0, 6.0, 8.0], 2, 2);
|
||||
let corr = pearson(&x, &y);
|
||||
let cov = covariance(&x, &y);
|
||||
assert!((corr - 1.0).abs() < 1e-8);
|
||||
assert!((cov - 2.5).abs() < 1e-8);
|
||||
```
|
||||
|
||||
## Covariance
|
||||
|
||||
### `covariance`
|
||||
|
||||
Computes the population covariance between two equally sized matrices by flattening
|
||||
their values.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::covariance;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let y = Matrix::from_vec(vec![2.0, 4.0, 6.0, 8.0], 2, 2);
|
||||
let cov = covariance(&x, &y);
|
||||
assert!((cov - 2.5).abs() < 1e-8);
|
||||
```
|
||||
|
||||
### `covariance_vertical`
|
||||
|
||||
Evaluates covariance between columns (i.e. across rows) and returns a matrix of
|
||||
column pair covariances.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::covariance_vertical;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let cov = covariance_vertical(&m);
|
||||
assert_eq!(cov.shape(), (2, 2));
|
||||
assert!(cov.data().iter().all(|&v| (v - 1.0).abs() < 1e-8));
|
||||
```
|
||||
|
||||
### `covariance_horizontal`
|
||||
|
||||
Computes covariance between rows (i.e. across columns) returning a matrix that
|
||||
describes how each pair of rows varies together.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::covariance_horizontal;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let cov = covariance_horizontal(&m);
|
||||
assert_eq!(cov.shape(), (2, 2));
|
||||
assert!(cov.data().iter().all(|&v| (v - 0.25).abs() < 1e-8));
|
||||
```
|
||||
|
||||
### `covariance_matrix`
|
||||
|
||||
Builds a covariance matrix either between columns (`Axis::Col`) or rows
|
||||
(`Axis::Row`). Each entry represents how two series co-vary.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::covariance_matrix;
|
||||
use rustframe::matrix::{Axis, Matrix};
|
||||
|
||||
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
|
||||
// Covariance between columns
|
||||
let cov_cols = covariance_matrix(&data, Axis::Col);
|
||||
assert!((cov_cols.get(0, 0) - 2.0).abs() < 1e-8);
|
||||
|
||||
// Covariance between rows
|
||||
let cov_rows = covariance_matrix(&data, Axis::Row);
|
||||
assert!((cov_rows.get(0, 1) + 0.5).abs() < 1e-8);
|
||||
```
|
||||
|
||||
## Distributions
|
||||
|
||||
Probability distribution helpers are available for common PDFs and CDFs.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::distributions::normal_pdf;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0], 1, 2);
|
||||
let pdf = normal_pdf(x, 0.0, 1.0);
|
||||
assert_eq!(pdf.data().len(), 2);
|
||||
```
|
||||
|
||||
### Additional Distributions
|
||||
|
||||
Rustframe provides several other probability distributions:
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::stats::distributions::{normal_cdf, binomial_pmf, binomial_cdf, poisson_pmf};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// Normal distribution CDF
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0], 1, 2);
|
||||
let cdf = normal_cdf(x, 0.0, 1.0);
|
||||
assert_eq!(cdf.data().len(), 2);
|
||||
|
||||
// Binomial distribution PMF
|
||||
// Probability of k successes in n trials with probability p
|
||||
let k = Matrix::from_vec(vec![0_u64, 1, 2, 3], 1, 4);
|
||||
let pmf = binomial_pmf(3, k.clone(), 0.5);
|
||||
assert_eq!(pmf.data().len(), 4);
|
||||
|
||||
// Binomial distribution CDF
|
||||
let cdf = binomial_cdf(3, k, 0.5);
|
||||
assert_eq!(cdf.data().len(), 4);
|
||||
|
||||
// Poisson distribution PMF
|
||||
// Probability of k events with rate parameter lambda
|
||||
let k = Matrix::from_vec(vec![0_u64, 1, 2], 1, 3);
|
||||
let pmf = poisson_pmf(2.0, k);
|
||||
assert_eq!(pmf.data().len(), 3);
|
||||
```
|
||||
|
||||
### Inferential Statistics
|
||||
|
||||
Rustframe provides several inferential statistical tests:
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::matrix::Matrix;
|
||||
use rustframe::compute::stats::inferential::{t_test, chi2_test, anova};
|
||||
|
||||
// Two-sample t-test
|
||||
let sample1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
let sample2 = Matrix::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
|
||||
let (t_statistic, p_value) = t_test(&sample1, &sample2);
|
||||
assert!((t_statistic + 5.0).abs() < 1e-5);
|
||||
assert!(p_value > 0.0 && p_value < 1.0);
|
||||
|
||||
// Chi-square test of independence
|
||||
let observed = Matrix::from_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
|
||||
let (chi2_statistic, p_value) = chi2_test(&observed);
|
||||
assert!(chi2_statistic > 0.0);
|
||||
assert!(p_value > 0.0 && p_value < 1.0);
|
||||
|
||||
// One-way ANOVA
|
||||
let group1 = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||
let group2 = Matrix::from_vec(vec![2.0, 3.0, 4.0], 1, 3);
|
||||
let group3 = Matrix::from_vec(vec![3.0, 4.0, 5.0], 1, 3);
|
||||
let groups = vec![&group1, &group2, &group3];
|
||||
let (f_statistic, p_value) = anova(groups);
|
||||
assert!(f_statistic > 0.0);
|
||||
assert!(p_value > 0.0 && p_value < 1.0);
|
||||
```
|
||||
|
||||
With the basics covered, explore predictive models in the
|
||||
[machine learning](./machine-learning.md) chapter.
|
||||
157
docs/src/data-manipulation.md
Normal file
157
docs/src/data-manipulation.md
Normal file
@@ -0,0 +1,157 @@
|
||||
# Data Manipulation
|
||||
|
||||
Rustframe's `Frame` type couples tabular data with
|
||||
column labels and a typed row index. Frames expose a familiar API for loading
|
||||
data, selecting rows or columns and performing aggregations.
|
||||
|
||||
## Creating a Frame
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::frame::{Frame, RowIndex};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let data = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
|
||||
let frame = Frame::new(data, vec!["A", "B"], None);
|
||||
assert_eq!(frame["A"], vec![1.0, 2.0]);
|
||||
```
|
||||
|
||||
## Indexing Rows
|
||||
|
||||
Row labels can be integers, dates or a default range. Retrieving a row returns a
|
||||
view that lets you inspect values by column name or position.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
# extern crate chrono;
|
||||
use chrono::NaiveDate;
|
||||
use rustframe::frame::{Frame, RowIndex};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let d = |y, m, d| NaiveDate::from_ymd_opt(y, m, d).unwrap();
|
||||
let data = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
|
||||
let index = RowIndex::Date(vec![d(2024, 1, 1), d(2024, 1, 2)]);
|
||||
let mut frame = Frame::new(data, vec!["A", "B"], Some(index));
|
||||
assert_eq!(frame.get_row_date(d(2024, 1, 2))["B"], 4.0);
|
||||
|
||||
// mutate by row key
|
||||
frame.get_row_date_mut(d(2024, 1, 1)).set_by_index(0, 9.0);
|
||||
assert_eq!(frame.get_row_date(d(2024, 1, 1))["A"], 9.0);
|
||||
```
|
||||
|
||||
## Column operations
|
||||
|
||||
Columns can be inserted, renamed, removed or reordered in place.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::frame::{Frame, RowIndex};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let data = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]);
|
||||
let mut frame = Frame::new(data, vec!["X", "Y"], Some(RowIndex::Range(0..2)));
|
||||
|
||||
frame.add_column("Z", vec![5, 6]);
|
||||
frame.rename("Y", "W");
|
||||
let removed = frame.delete_column("X");
|
||||
assert_eq!(removed, vec![1, 2]);
|
||||
frame.sort_columns();
|
||||
assert_eq!(frame.columns(), &["W", "Z"]);
|
||||
```
|
||||
|
||||
## Aggregations
|
||||
|
||||
Any numeric aggregation available on `Matrix` is forwarded to `Frame`.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::frame::Frame;
|
||||
use rustframe::matrix::{Matrix, SeriesOps};
|
||||
|
||||
let frame = Frame::new(Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]), vec!["A", "B"], None);
|
||||
assert_eq!(frame.sum_vertical(), vec![3.0, 7.0]);
|
||||
assert_eq!(frame.sum_horizontal(), vec![4.0, 6.0]);
|
||||
```
|
||||
|
||||
## Matrix Operations
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let data1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let data2 = Matrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
|
||||
|
||||
let sum = data1.clone() + data2.clone();
|
||||
assert_eq!(sum.data(), vec![6.0, 8.0, 10.0, 12.0]);
|
||||
|
||||
let product = data1.clone() * data2.clone();
|
||||
assert_eq!(product.data(), vec![5.0, 12.0, 21.0, 32.0]);
|
||||
|
||||
let scalar_product = data1.clone() * 2.0;
|
||||
assert_eq!(scalar_product.data(), vec![2.0, 4.0, 6.0, 8.0]);
|
||||
|
||||
let equals = data1 == data1.clone();
|
||||
assert_eq!(equals, true);
|
||||
```
|
||||
|
||||
### Advanced Matrix Operations
|
||||
|
||||
Matrices support a variety of advanced operations:
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::matrix::{Matrix, SeriesOps};
|
||||
|
||||
// Matrix multiplication (dot product)
|
||||
let a = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let b = Matrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
|
||||
let product = a.matrix_mul(&b);
|
||||
assert_eq!(product.data(), vec![23.0, 34.0, 31.0, 46.0]);
|
||||
|
||||
// Transpose
|
||||
let m = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let transposed = m.transpose();
|
||||
assert_eq!(transposed.data(), vec![1.0, 3.0, 2.0, 4.0]);
|
||||
|
||||
// Map function over all elements
|
||||
let m = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let squared = m.map(|x| x * x);
|
||||
assert_eq!(squared.data(), vec![1.0, 4.0, 9.0, 16.0]);
|
||||
|
||||
// Zip two matrices with a function
|
||||
let a = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let b = Matrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2);
|
||||
let zipped = a.zip(&b, |x, y| x + y);
|
||||
assert_eq!(zipped.data(), vec![6.0, 8.0, 10.0, 12.0]);
|
||||
```
|
||||
|
||||
### Matrix Reductions
|
||||
|
||||
Matrices support various reduction operations:
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::matrix::{Matrix, SeriesOps};
|
||||
|
||||
let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
|
||||
|
||||
// Sum along columns (vertical)
|
||||
let col_sums = m.sum_vertical();
|
||||
assert_eq!(col_sums, vec![9.0, 12.0]); // [1+3+5, 2+4+6]
|
||||
|
||||
// Sum along rows (horizontal)
|
||||
let row_sums = m.sum_horizontal();
|
||||
assert_eq!(row_sums, vec![3.0, 7.0, 11.0]); // [1+2, 3+4, 5+6]
|
||||
|
||||
// Cumulative sum along columns
|
||||
let col_cumsum = m.cumsum_vertical();
|
||||
assert_eq!(col_cumsum.data(), vec![1.0, 4.0, 9.0, 2.0, 6.0, 12.0]);
|
||||
|
||||
// Cumulative sum along rows
|
||||
let row_cumsum = m.cumsum_horizontal();
|
||||
assert_eq!(row_cumsum.data(), vec![1.0, 3.0, 5.0, 3.0, 7.0, 11.0]);
|
||||
```
|
||||
|
||||
With the basics covered, continue to the [compute features](./compute.md)
|
||||
chapter for statistics and analytics.
|
||||
40
docs/src/introduction.md
Normal file
40
docs/src/introduction.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Introduction
|
||||
|
||||
🐙 [GitHub](https://github.com/Magnus167/rustframe) | 📚 [Docs](https://magnus167.github.io/rustframe/) | 📖 [User Guide](https://magnus167.github.io/rustframe/user-guide/) | 🦀 [Crates.io](https://crates.io/crates/rustframe) | 🔖 [docs.rs](https://docs.rs/rustframe/latest/rustframe/)
|
||||
|
||||
Welcome to the **Rustframe User Guide**. Rustframe is a lightweight dataframe
|
||||
and math toolkit for Rust written in 100% safe Rust. It focuses on keeping the
|
||||
API approachable while offering handy features for small analytical or
|
||||
educational projects.
|
||||
|
||||
Rustframe bundles:
|
||||
|
||||
- column‑labelled frames built on a fast column‑major matrix
|
||||
- familiar element‑wise math and aggregation routines
|
||||
- a growing `compute` module for statistics and machine learning
|
||||
- utilities for dates and random numbers
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::{frame::Frame, matrix::{Matrix, SeriesOps}};
|
||||
|
||||
let data = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
|
||||
let frame = Frame::new(data, vec!["A", "B"], None);
|
||||
|
||||
// Perform column wise aggregation
|
||||
assert_eq!(frame.sum_vertical(), vec![3.0, 7.0]);
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [GitHub repository](https://github.com/Magnus167/rustframe)
|
||||
- [Crates.io](https://crates.io/crates/rustframe) & [API docs](https://docs.rs/rustframe)
|
||||
- [Code coverage](https://codecov.io/gh/Magnus167/rustframe)
|
||||
|
||||
This guide walks through the main building blocks of the library. Each chapter
|
||||
contains runnable snippets so you can follow along:
|
||||
|
||||
1. [Data manipulation](./data-manipulation.md) for loading and transforming data
|
||||
2. [Compute features](./compute.md) for statistics and analytics
|
||||
3. [Machine learning](./machine-learning.md) for predictive models
|
||||
4. [Utilities](./utilities.md) for supporting helpers and upcoming modules
|
||||
282
docs/src/machine-learning.md
Normal file
282
docs/src/machine-learning.md
Normal file
@@ -0,0 +1,282 @@
|
||||
# Machine Learning
|
||||
|
||||
The `compute::models` module bundles several learning algorithms that operate on
|
||||
`Matrix` structures. These examples highlight the basic training and prediction
|
||||
APIs. For more end‑to‑end walkthroughs see the examples directory in the
|
||||
repository.
|
||||
|
||||
Currently implemented models include:
|
||||
|
||||
- Linear and logistic regression
|
||||
- K‑means clustering
|
||||
- Principal component analysis (PCA)
|
||||
- Gaussian Naive Bayes
|
||||
- Dense neural networks
|
||||
|
||||
## Linear Regression
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::linreg::LinReg;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||
let y = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0], 4, 1);
|
||||
let mut model = LinReg::new(1);
|
||||
model.fit(&x, &y, 0.01, 100);
|
||||
let preds = model.predict(&x);
|
||||
assert_eq!(preds.rows(), 4);
|
||||
```
|
||||
|
||||
## K-means Walkthrough
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::k_means::KMeans;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let data = Matrix::from_vec(vec![1.0, 1.0, 5.0, 5.0], 2, 2);
|
||||
let (model, _labels) = KMeans::fit(&data, 2, 10, 1e-4);
|
||||
let new_point = Matrix::from_vec(vec![0.0, 0.0], 1, 2);
|
||||
let cluster = model.predict(&new_point)[0];
|
||||
```
|
||||
|
||||
## Logistic Regression
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::logreg::LogReg;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
|
||||
let mut model = LogReg::new(1);
|
||||
model.fit(&x, &y, 0.1, 200);
|
||||
let preds = model.predict_proba(&x);
|
||||
assert_eq!(preds.rows(), 4);
|
||||
```
|
||||
|
||||
## Principal Component Analysis
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::pca::PCA;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
let data = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let pca = PCA::fit(&data, 1, 0);
|
||||
let transformed = pca.transform(&data);
|
||||
assert_eq!(transformed.cols(), 1);
|
||||
```
|
||||
|
||||
## Gaussian Naive Bayes
|
||||
|
||||
Gaussian Naive Bayes classifier for continuous features:
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::gaussian_nb::GaussianNB;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// Training data with 2 features
|
||||
let x = Matrix::from_rows_vec(vec![
|
||||
1.0, 2.0,
|
||||
2.0, 3.0,
|
||||
3.0, 4.0,
|
||||
4.0, 5.0
|
||||
], 4, 2);
|
||||
|
||||
// Class labels (0 or 1)
|
||||
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
|
||||
|
||||
// Train the model
|
||||
let mut model = GaussianNB::new(1e-9, true);
|
||||
model.fit(&x, &y);
|
||||
|
||||
// Make predictions
|
||||
let predictions = model.predict(&x);
|
||||
assert_eq!(predictions.rows(), 4);
|
||||
```
|
||||
|
||||
## Dense Neural Networks
|
||||
|
||||
Simple fully connected neural network:
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::dense_nn::{DenseNN, DenseNNConfig, ActivationKind, InitializerKind, LossKind};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// Training data with 2 features
|
||||
let x = Matrix::from_rows_vec(vec![
|
||||
0.0, 0.0,
|
||||
0.0, 1.0,
|
||||
1.0, 0.0,
|
||||
1.0, 1.0
|
||||
], 4, 2);
|
||||
|
||||
// XOR target outputs
|
||||
let y = Matrix::from_vec(vec![0.0, 1.0, 1.0, 0.0], 4, 1);
|
||||
|
||||
// Create a neural network with 2 hidden layers
|
||||
let config = DenseNNConfig {
|
||||
input_size: 2,
|
||||
hidden_layers: vec![4, 4],
|
||||
output_size: 1,
|
||||
activations: vec![ActivationKind::Sigmoid, ActivationKind::Sigmoid, ActivationKind::Sigmoid],
|
||||
initializer: InitializerKind::Uniform(0.5),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 0.1,
|
||||
epochs: 1000,
|
||||
};
|
||||
let mut model = DenseNN::new(config);
|
||||
|
||||
// Train the model
|
||||
model.train(&x, &y);
|
||||
|
||||
// Make predictions
|
||||
let predictions = model.predict(&x);
|
||||
assert_eq!(predictions.rows(), 4);
|
||||
```
|
||||
|
||||
## Real-world Examples
|
||||
|
||||
### Housing Price Prediction
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::linreg::LinReg;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// Features: square feet and bedrooms
|
||||
let features = Matrix::from_rows_vec(vec![
|
||||
2100.0, 3.0,
|
||||
1600.0, 2.0,
|
||||
2400.0, 4.0,
|
||||
1400.0, 2.0,
|
||||
], 4, 2);
|
||||
|
||||
// Sale prices
|
||||
let target = Matrix::from_vec(vec![400_000.0, 330_000.0, 369_000.0, 232_000.0], 4, 1);
|
||||
|
||||
let mut model = LinReg::new(2);
|
||||
model.fit(&features, &target, 1e-8, 10_000);
|
||||
|
||||
// Predict price of a new home
|
||||
let new_home = Matrix::from_vec(vec![2000.0, 3.0], 1, 2);
|
||||
let predicted_price = model.predict(&new_home);
|
||||
println!("Predicted price: ${}", predicted_price.data()[0]);
|
||||
```
|
||||
|
||||
### Spam Detection
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::logreg::LogReg;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// 20 e-mails × 5 features = 100 numbers (row-major, spam first)
|
||||
let x = Matrix::from_rows_vec(
|
||||
vec![
|
||||
// ─────────── spam examples ───────────
|
||||
2.0, 1.0, 1.0, 1.0, 1.0, // "You win a FREE offer - click for money-back bonus!"
|
||||
1.0, 0.0, 1.0, 1.0, 0.0, // "FREE offer! Click now!"
|
||||
0.0, 2.0, 0.0, 1.0, 1.0, // "Win win win - money inside, click…"
|
||||
1.0, 1.0, 0.0, 0.0, 1.0, // "Limited offer to win easy money…"
|
||||
1.0, 0.0, 1.0, 0.0, 1.0, // ...
|
||||
0.0, 1.0, 1.0, 1.0, 0.0, // ...
|
||||
2.0, 0.0, 0.0, 1.0, 1.0, // ...
|
||||
0.0, 1.0, 1.0, 0.0, 1.0, // ...
|
||||
1.0, 1.0, 1.0, 1.0, 0.0, // ...
|
||||
1.0, 0.0, 0.0, 1.0, 1.0, // ...
|
||||
// ─────────── ham examples ───────────
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, // "See you at the meeting tomorrow."
|
||||
0.0, 0.0, 0.0, 1.0, 0.0, // "Here's the Zoom click-link."
|
||||
0.0, 0.0, 0.0, 0.0, 1.0, // "Expense report: money attached."
|
||||
0.0, 0.0, 0.0, 1.0, 1.0, // ...
|
||||
0.0, 1.0, 0.0, 0.0, 0.0, // "Did we win the bid?"
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, // ...
|
||||
0.0, 0.0, 0.0, 1.0, 0.0, // ...
|
||||
1.0, 0.0, 0.0, 0.0, 0.0, // "Special offer for staff lunch."
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, // ...
|
||||
0.0, 0.0, 0.0, 1.0, 0.0,
|
||||
],
|
||||
20,
|
||||
5,
|
||||
);
|
||||
|
||||
// Labels: 1 = spam, 0 = ham
|
||||
let y = Matrix::from_vec(
|
||||
vec![
|
||||
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // 10 spam
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, // 10 ham
|
||||
],
|
||||
20,
|
||||
1,
|
||||
);
|
||||
|
||||
// Train
|
||||
let mut model = LogReg::new(5);
|
||||
model.fit(&x, &y, 0.01, 5000);
|
||||
|
||||
// Predict
|
||||
// e.g. "free money offer"
|
||||
let email_data = vec![1.0, 0.0, 1.0, 0.0, 1.0];
|
||||
let email = Matrix::from_vec(email_data, 1, 5);
|
||||
let prob_spam = model.predict_proba(&email);
|
||||
println!("Probability of spam: {:.4}", prob_spam.data()[0]);
|
||||
```
|
||||
|
||||
### Iris Flower Classification
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::gaussian_nb::GaussianNB;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// Features: sepal length and petal length
|
||||
let x = Matrix::from_rows_vec(vec![
|
||||
5.1, 1.4, // setosa
|
||||
4.9, 1.4, // setosa
|
||||
6.2, 4.5, // versicolor
|
||||
5.9, 5.1, // virginica
|
||||
], 4, 2);
|
||||
|
||||
let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 2.0], 4, 1);
|
||||
let names = vec!["setosa", "versicolor", "virginica"];
|
||||
|
||||
let mut model = GaussianNB::new(1e-9, true);
|
||||
model.fit(&x, &y);
|
||||
|
||||
let sample = Matrix::from_vec(vec![5.0, 1.5], 1, 2);
|
||||
let predicted_class = model.predict(&sample);
|
||||
let class_name = names[predicted_class.data()[0] as usize];
|
||||
println!("Predicted class: {} ({:?})", class_name, predicted_class.data()[0]);
|
||||
```
|
||||
|
||||
### Customer Segmentation
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::compute::models::k_means::KMeans;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
// Each row: [age, annual_income]
|
||||
let customers = Matrix::from_rows_vec(
|
||||
vec![
|
||||
25.0, 40_000.0, 34.0, 52_000.0, 58.0, 95_000.0, 45.0, 70_000.0,
|
||||
],
|
||||
4,
|
||||
2,
|
||||
);
|
||||
|
||||
let (model, labels) = KMeans::fit(&customers, 2, 20, 1e-4);
|
||||
|
||||
let new_customer = Matrix::from_vec(vec![30.0, 50_000.0], 1, 2);
|
||||
let cluster = model.predict(&new_customer)[0];
|
||||
println!("New customer belongs to cluster: {}", cluster);
|
||||
println!("Cluster labels: {:?}", labels);
|
||||
```
|
||||
|
||||
For helper functions and upcoming modules, visit the
|
||||
[utilities](./utilities.md) section.
|
||||
63
docs/src/utilities.md
Normal file
63
docs/src/utilities.md
Normal file
@@ -0,0 +1,63 @@
|
||||
# Utilities
|
||||
|
||||
Utilities provide handy helpers around the core library. Existing tools
|
||||
include:
|
||||
|
||||
- Date utilities for generating calendar sequences and business‑day sets
|
||||
- Random number generators for simulations and testing
|
||||
|
||||
## Date Helpers
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::utils::dateutils::{BDatesList, BDateFreq, DatesList, DateFreq};
|
||||
|
||||
// Calendar sequence
|
||||
let list = DatesList::new("2024-01-01".into(), "2024-01-03".into(), DateFreq::Daily);
|
||||
assert_eq!(list.count().unwrap(), 3);
|
||||
|
||||
// Business days starting from 2024‑01‑02
|
||||
let bdates = BDatesList::from_n_periods("2024-01-02".into(), BDateFreq::Daily, 3).unwrap();
|
||||
assert_eq!(bdates.list().unwrap().len(), 3);
|
||||
```
|
||||
|
||||
## Random Numbers
|
||||
|
||||
The `random` module offers deterministic and cryptographically secure RNGs.
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::random::{Prng, Rng};
|
||||
|
||||
let mut rng = Prng::new(42);
|
||||
let v1 = rng.next_u64();
|
||||
let v2 = rng.next_u64();
|
||||
assert_ne!(v1, v2);
|
||||
```
|
||||
|
||||
## Stats Functions
|
||||
|
||||
```rust
|
||||
# extern crate rustframe;
|
||||
use rustframe::matrix::Matrix;
|
||||
use rustframe::compute::stats::descriptive::{mean, median, stddev};
|
||||
|
||||
let data = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
|
||||
let mean_value = mean(&data);
|
||||
assert_eq!(mean_value, 3.0);
|
||||
|
||||
let median_value = median(&data);
|
||||
assert_eq!(median_value, 3.0);
|
||||
|
||||
let std_value = stddev(&data);
|
||||
assert_eq!(std_value, 2.0_f64.sqrt());
|
||||
```
|
||||
|
||||
Upcoming utilities will cover:
|
||||
|
||||
- Data import/export helpers
|
||||
- Visualization adapters
|
||||
- Streaming data interfaces
|
||||
|
||||
Contributions to these sections are welcome!
|
||||
45
examples/correlation.rs
Normal file
45
examples/correlation.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use rustframe::compute::stats::{covariance, covariance_matrix, pearson};
|
||||
use rustframe::matrix::{Axis, Matrix};
|
||||
|
||||
/// Demonstrates covariance and correlation utilities.
|
||||
fn main() {
|
||||
pairwise_cov();
|
||||
println!("\n-----\n");
|
||||
matrix_cov();
|
||||
}
|
||||
|
||||
fn pairwise_cov() {
|
||||
println!("Covariance & Pearson r\n----------------------");
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let y = Matrix::from_vec(vec![1.0, 2.0, 3.0, 5.0], 2, 2);
|
||||
println!("covariance : {:.2}", covariance(&x, &y));
|
||||
println!("pearson r : {:.3}", pearson(&x, &y));
|
||||
}
|
||||
|
||||
fn matrix_cov() {
|
||||
println!("Covariance matrix\n-----------------");
|
||||
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let cov = covariance_matrix(&data, Axis::Col);
|
||||
println!("cov matrix : {:?}", cov.data());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
const EPS: f64 = 1e-8;
|
||||
|
||||
#[test]
|
||||
fn test_pairwise_cov() {
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let y = Matrix::from_vec(vec![1.0, 2.0, 3.0, 5.0], 2, 2);
|
||||
assert!((covariance(&x, &y) - 1.625).abs() < EPS);
|
||||
assert!((pearson(&x, &y) - 0.9827076298239908).abs() < 1e-5,);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_cov() {
|
||||
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let cov = covariance_matrix(&data, Axis::Col);
|
||||
assert_eq!(cov.data(), &[2.0, 2.0, 2.0, 2.0]);
|
||||
}
|
||||
}
|
||||
56
examples/descriptive_stats.rs
Normal file
56
examples/descriptive_stats.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use rustframe::compute::stats::{mean, mean_horizontal, mean_vertical, median, percentile, stddev};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Demonstrates descriptive statistics utilities.
|
||||
///
|
||||
/// Part 1: simple mean/stddev/median/percentile on a vector.
|
||||
/// Part 2: mean across rows and columns.
|
||||
fn main() {
|
||||
simple_stats();
|
||||
println!("\n-----\n");
|
||||
axis_stats();
|
||||
}
|
||||
|
||||
fn simple_stats() {
|
||||
println!("Basic stats\n-----------");
|
||||
let data = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
println!("mean : {:.2}", mean(&data));
|
||||
println!("stddev : {:.2}", stddev(&data));
|
||||
println!("median : {:.2}", median(&data));
|
||||
println!("90th pct. : {:.2}", percentile(&data, 90.0));
|
||||
}
|
||||
|
||||
fn axis_stats() {
|
||||
println!("Row/column means\n----------------");
|
||||
// 2x3 matrix
|
||||
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
|
||||
let v = mean_vertical(&data); // 1x3
|
||||
let h = mean_horizontal(&data); // 2x1
|
||||
println!("vertical means : {:?}", v.data());
|
||||
println!("horizontal means: {:?}", h.data());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
const EPS: f64 = 1e-8;
|
||||
|
||||
#[test]
|
||||
fn test_simple_stats() {
|
||||
let data = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
assert!((mean(&data) - 3.0).abs() < EPS);
|
||||
assert!((stddev(&data) - 1.4142135623730951).abs() < EPS);
|
||||
assert!((median(&data) - 3.0).abs() < EPS);
|
||||
assert!((percentile(&data, 90.0) - 5.0).abs() < EPS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_axis_stats() {
|
||||
let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
|
||||
let v = mean_vertical(&data);
|
||||
assert_eq!(v.data(), &[2.5, 3.5, 4.5]);
|
||||
let h = mean_horizontal(&data);
|
||||
assert_eq!(h.data(), &[2.0, 5.0]);
|
||||
}
|
||||
}
|
||||
|
||||
66
examples/distributions.rs
Normal file
66
examples/distributions.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use rustframe::compute::stats::{binomial_cdf, binomial_pmf, normal_cdf, normal_pdf, poisson_pmf};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Demonstrates some probability distribution helpers.
|
||||
fn main() {
|
||||
normal_example();
|
||||
println!("\n-----\n");
|
||||
binomial_example();
|
||||
println!("\n-----\n");
|
||||
poisson_example();
|
||||
}
|
||||
|
||||
fn normal_example() {
|
||||
println!("Normal distribution\n-------------------");
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0], 1, 2);
|
||||
let pdf = normal_pdf(x.clone(), 0.0, 1.0);
|
||||
let cdf = normal_cdf(x, 0.0, 1.0);
|
||||
println!("pdf : {:?}", pdf.data());
|
||||
println!("cdf : {:?}", cdf.data());
|
||||
}
|
||||
|
||||
fn binomial_example() {
|
||||
println!("Binomial distribution\n---------------------");
|
||||
let k = Matrix::from_vec(vec![0_u64, 1, 2], 1, 3);
|
||||
let pmf = binomial_pmf(4, k.clone(), 0.5);
|
||||
let cdf = binomial_cdf(4, k, 0.5);
|
||||
println!("pmf : {:?}", pmf.data());
|
||||
println!("cdf : {:?}", cdf.data());
|
||||
}
|
||||
|
||||
fn poisson_example() {
|
||||
println!("Poisson distribution\n--------------------");
|
||||
let k = Matrix::from_vec(vec![0_u64, 1, 2], 1, 3);
|
||||
let pmf = poisson_pmf(3.0, k);
|
||||
println!("pmf : {:?}", pmf.data());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_normal_example() {
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0], 1, 2);
|
||||
let pdf = normal_pdf(x.clone(), 0.0, 1.0);
|
||||
let cdf = normal_cdf(x, 0.0, 1.0);
|
||||
assert!((pdf.get(0, 0) - 0.39894228).abs() < 1e-6);
|
||||
assert!((cdf.get(0, 1) - 0.8413447).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binomial_example() {
|
||||
let k = Matrix::from_vec(vec![0_u64, 1, 2], 1, 3);
|
||||
let pmf = binomial_pmf(4, k.clone(), 0.5);
|
||||
let cdf = binomial_cdf(4, k, 0.5);
|
||||
assert!((pmf.get(0, 2) - 0.375).abs() < 1e-6);
|
||||
assert!((cdf.get(0, 2) - 0.6875).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_poisson_example() {
|
||||
let k = Matrix::from_vec(vec![0_u64, 1, 2], 1, 3);
|
||||
let pmf = poisson_pmf(3.0, k);
|
||||
assert!((pmf.get(0, 1) - 3.0_f64 * (-3.0_f64).exp()).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
413
examples/game_of_life.rs
Normal file
413
examples/game_of_life.rs
Normal file
@@ -0,0 +1,413 @@
|
||||
//! Conway's Game of Life Example
|
||||
//! This example implements Conway's Game of Life using a `BoolMatrix` to represent the game board.
|
||||
//! It demonstrates matrix operations like shifting, counting neighbors, and applying game rules.
|
||||
//! The game runs in a loop, updating the board state and printing it to the console.
|
||||
//! To modify the behaviour of the example, please change the constants at the top of this file.
|
||||
|
||||
|
||||
use rustframe::matrix::{BoolMatrix, BoolOps, IntMatrix, Matrix};
|
||||
use rustframe::random::{rng, Rng};
|
||||
use std::{thread, time};
|
||||
|
||||
const BOARD_SIZE: usize = 20; // Size of the board (50x50)
|
||||
const MAX_FRAMES: u32 = 1000;
|
||||
|
||||
const TICK_DURATION_MS: u64 = 0; // Milliseconds per frame
|
||||
const SKIP_FRAMES: u32 = 1;
|
||||
const PRINT_BOARD: bool = true; // Set to false to disable printing the board
|
||||
|
||||
fn main() {
|
||||
let args = std::env::args().collect::<Vec<String>>();
|
||||
let debug_mode = args.contains(&"--debug".to_string());
|
||||
let print_mode = if debug_mode { false } else { PRINT_BOARD };
|
||||
|
||||
let mut current_board =
|
||||
BoolMatrix::from_vec(vec![false; BOARD_SIZE * BOARD_SIZE], BOARD_SIZE, BOARD_SIZE);
|
||||
|
||||
let primes = generate_primes((BOARD_SIZE * BOARD_SIZE) as i32);
|
||||
|
||||
add_simulated_activity(&mut current_board, BOARD_SIZE);
|
||||
|
||||
let mut generation_count: u32 = 0;
|
||||
let mut previous_board_state: Option<BoolMatrix> = None;
|
||||
let mut board_hashes = Vec::new();
|
||||
let mut print_bool_int = 0;
|
||||
|
||||
loop {
|
||||
if print_bool_int % SKIP_FRAMES == 0 {
|
||||
print_board(¤t_board, generation_count, print_mode);
|
||||
|
||||
print_bool_int = 0;
|
||||
} else {
|
||||
print_bool_int += 1;
|
||||
}
|
||||
board_hashes.push(hash_board(¤t_board, primes.clone()));
|
||||
if detect_stable_state(¤t_board, &previous_board_state) {
|
||||
println!(
|
||||
"\nStable state detected at generation {}.",
|
||||
generation_count
|
||||
);
|
||||
add_simulated_activity(&mut current_board, BOARD_SIZE);
|
||||
}
|
||||
if detect_repeating_state(&mut board_hashes) {
|
||||
println!(
|
||||
"\nRepeating state detected at generation {}.",
|
||||
generation_count
|
||||
);
|
||||
add_simulated_activity(&mut current_board, BOARD_SIZE);
|
||||
}
|
||||
if !¤t_board.any() {
|
||||
println!("\nExtinction at generation {}.", generation_count);
|
||||
add_simulated_activity(&mut current_board, BOARD_SIZE);
|
||||
}
|
||||
|
||||
previous_board_state = Some(current_board.clone());
|
||||
|
||||
let next_board = game_of_life_next_frame(¤t_board);
|
||||
current_board = next_board;
|
||||
|
||||
generation_count += 1;
|
||||
thread::sleep(time::Duration::from_millis(TICK_DURATION_MS));
|
||||
|
||||
if (MAX_FRAMES > 0) && (generation_count > MAX_FRAMES) {
|
||||
println!("\nReached generation limit.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prints the Game of Life board to the console.
|
||||
///
|
||||
/// - `board`: A reference to the `BoolMatrix` representing the current game state.
|
||||
/// This function demonstrates `board.rows()`, `board.cols()`, and `board[(r, c)]` (Index trait).
|
||||
fn print_board(board: &BoolMatrix, generation_count: u32, print_mode: bool) {
|
||||
if !print_mode {
|
||||
return;
|
||||
}
|
||||
|
||||
print!("{}[2J", 27 as char);
|
||||
println!("Conway's Game of Life - Generation: {}", generation_count);
|
||||
let mut print_str = String::new();
|
||||
print_str.push_str("+");
|
||||
for _ in 0..board.cols() {
|
||||
print_str.push_str("--");
|
||||
}
|
||||
print_str.push_str("+\n");
|
||||
for r in 0..board.rows() {
|
||||
print_str.push_str("| ");
|
||||
for c in 0..board.cols() {
|
||||
if board[(r, c)] {
|
||||
print_str.push_str("██");
|
||||
} else {
|
||||
print_str.push_str(" ");
|
||||
}
|
||||
}
|
||||
print_str.push_str(" |\n");
|
||||
}
|
||||
print_str.push_str("+");
|
||||
for _ in 0..board.cols() {
|
||||
print_str.push_str("--");
|
||||
}
|
||||
print_str.push_str("+\n\n");
|
||||
print!("{}", print_str);
|
||||
|
||||
println!("Alive cells: {}", board.count());
|
||||
}
|
||||
|
||||
/// Helper function to create a shifted version of the game board.
|
||||
/// (Using the version provided by the user)
|
||||
///
|
||||
/// - `game`: The current state of the Game of Life as a `BoolMatrix`.
|
||||
/// - `dr`: The row shift (delta row). Positive shifts down, negative shifts up.
|
||||
/// - `dc`: The column shift (delta column). Positive shifts right, negative shifts left.
|
||||
///
|
||||
/// Returns an `IntMatrix` of the same dimensions as `game`.
|
||||
/// - Cells in the shifted matrix get value `1` if the corresponding source cell in `game` was `true` (alive).
|
||||
/// - Cells that would source from outside `game`'s bounds (due to the shift) get value `0`.
|
||||
fn get_shifted_neighbor_layer(game: &BoolMatrix, dr: isize, dc: isize) -> IntMatrix {
|
||||
let rows = game.rows();
|
||||
let cols = game.cols();
|
||||
|
||||
if rows == 0 || cols == 0 {
|
||||
// Handle 0x0 case, other 0-dim cases panic in Matrix::from_vec
|
||||
return IntMatrix::from_vec(vec![], 0, 0);
|
||||
}
|
||||
|
||||
// Initialize with a matrix of 0s using from_vec.
|
||||
// This demonstrates creating an IntMatrix and then populating it.
|
||||
let mut shifted_layer = IntMatrix::from_vec(vec![0i32; rows * cols], rows, cols);
|
||||
|
||||
for r_target in 0..rows {
|
||||
// Iterate over cells in the *new* (target) shifted matrix
|
||||
for c_target in 0..cols {
|
||||
// Calculate where this target cell would have come from in the *original* game matrix
|
||||
let r_source = r_target as isize - dr;
|
||||
let c_source = c_target as isize - dc;
|
||||
|
||||
// Check if the source coordinates are within the bounds of the original game matrix
|
||||
if r_source >= 0
|
||||
&& r_source < rows as isize
|
||||
&& c_source >= 0
|
||||
&& c_source < cols as isize
|
||||
{
|
||||
// If the source cell in the original game was alive...
|
||||
if game[(r_source as usize, c_source as usize)] {
|
||||
// Demonstrates Index access on BoolMatrix
|
||||
// ...then this cell in the shifted layer is 1.
|
||||
shifted_layer[(r_target, c_target)] = 1; // Demonstrates IndexMut access on IntMatrix
|
||||
}
|
||||
}
|
||||
// Else (source is out of bounds): it remains 0, as initialized.
|
||||
}
|
||||
}
|
||||
shifted_layer // Return the constructed IntMatrix
|
||||
}
|
||||
|
||||
/// Calculates the next generation of Conway's Game of Life.
|
||||
///
|
||||
/// This implementation uses a broadcast-like approach by creating shifted layers
|
||||
/// for each neighbor and summing them up, then applying rules element-wise.
|
||||
///
|
||||
/// - `current_game`: A `&BoolMatrix` representing the current state (true=alive).
|
||||
///
|
||||
/// Returns: A new `BoolMatrix` for the next generation.
|
||||
pub fn game_of_life_next_frame(current_game: &BoolMatrix) -> BoolMatrix {
|
||||
let rows = current_game.rows();
|
||||
let cols = current_game.cols();
|
||||
|
||||
if rows == 0 && cols == 0 {
|
||||
return BoolMatrix::from_vec(vec![], 0, 0); // Return an empty BoolMatrix
|
||||
}
|
||||
// Define the 8 neighbor offsets (row_delta, col_delta)
|
||||
let neighbor_offsets: [(isize, isize); 8] = [
|
||||
(-1, -1),
|
||||
(-1, 0),
|
||||
(-1, 1),
|
||||
(0, -1),
|
||||
(0, 1),
|
||||
(1, -1),
|
||||
(1, 0),
|
||||
(1, 1),
|
||||
];
|
||||
|
||||
let (first_dr, first_dc) = neighbor_offsets[0];
|
||||
let mut neighbor_counts = get_shifted_neighbor_layer(current_game, first_dr, first_dc);
|
||||
|
||||
for i in 1..neighbor_offsets.len() {
|
||||
let (dr, dc) = neighbor_offsets[i];
|
||||
let next_neighbor_layer = get_shifted_neighbor_layer(current_game, dr, dc);
|
||||
neighbor_counts = neighbor_counts + next_neighbor_layer;
|
||||
}
|
||||
|
||||
let has_2_neighbors = neighbor_counts.eq_elem(2);
|
||||
let has_3_neighbors = neighbor_counts.eq_elem(3);
|
||||
|
||||
let has_2_or_3_neighbors = has_2_neighbors | has_3_neighbors.clone();
|
||||
|
||||
let survives = current_game & &has_2_or_3_neighbors;
|
||||
|
||||
let is_dead = !current_game;
|
||||
|
||||
let births = is_dead & &has_3_neighbors;
|
||||
|
||||
let next_frame_game = survives | births;
|
||||
|
||||
next_frame_game
|
||||
}
|
||||
|
||||
/// Stamps a glider pattern at a random position on `board`.
///
/// The glider occupies a 3x3 area whose top-left corner is chosen
/// uniformly at random so the whole pattern fits within `board_size`.
/// Demonstrates `IndexMut` via `board[(r, c)] = true;`.
///
/// BUGFIX: `board_size - 3` previously underflowed (usize) for
/// `board_size < 3`, and `board_size == 3` produced an empty sample
/// range; boards too small for a glider are now silently skipped.
pub fn generate_glider(board: &mut BoolMatrix, board_size: usize) {
    // Need at least a 4-wide size for a non-empty 0..(board_size - 3)
    // range, and a board large enough to hold the 3x3 pattern.
    if board_size < 4 || board.rows() < 3 || board.cols() < 3 {
        return;
    }
    let mut rng = rng();
    let r_offset = rng.random_range(0..(board_size - 3));
    let c_offset = rng.random_range(0..(board_size - 3));
    if board.rows() >= r_offset + 3 && board.cols() >= c_offset + 3 {
        // Classic glider shape.
        board[(r_offset + 0, c_offset + 1)] = true;
        board[(r_offset + 1, c_offset + 2)] = true;
        board[(r_offset + 2, c_offset + 0)] = true;
        board[(r_offset + 2, c_offset + 1)] = true;
        board[(r_offset + 2, c_offset + 2)] = true;
    }
}
|
||||
|
||||
/// Stamps a pulsar-style pattern at a random position on `board`.
///
/// The random offset is chosen so a 17x17 area fits inside `board_size`.
/// NOTE(review): the coordinate list below only covers rows 2-7 and
/// columns 2-14 — it looks like the top half of a full 13x13 pulsar;
/// confirm whether the bottom half was omitted intentionally.
///
/// `board_size - 17` overflows for `board_size < 17` (panics in debug
/// builds) — callers must supply a sufficiently large board.
/// Demonstrates `IndexMut` via `board[(r, c)] = true;`.
pub fn generate_pulsar(board: &mut BoolMatrix, board_size: usize) {
    // Pick a random top-left corner for the 17x17 bounding area.
    let mut rng = rng();
    let r_offset = rng.random_range(0..(board_size - 17));
    let c_offset = rng.random_range(0..(board_size - 17));
    // Only stamp the pattern when the bounding area really fits.
    if board.rows() >= r_offset + 17 && board.cols() >= c_offset + 17 {
        // (row, col) offsets of live cells relative to the corner.
        let pulsar_coords = [
            (2, 4),
            (2, 5),
            (2, 6),
            (2, 10),
            (2, 11),
            (2, 12),
            (4, 2),
            (4, 7),
            (4, 9),
            (4, 14),
            (5, 2),
            (5, 7),
            (5, 9),
            (5, 14),
            (6, 2),
            (6, 7),
            (6, 9),
            (6, 14),
            (7, 4),
            (7, 5),
            (7, 6),
            (7, 10),
            (7, 11),
            (7, 12),
        ];
        for &(dr, dc) in pulsar_coords.iter() {
            board[(r_offset + dr, c_offset + dc)] = true;
        }
    }
}
|
||||
|
||||
pub fn detect_stable_state(
|
||||
current_board: &BoolMatrix,
|
||||
previous_board_state: &Option<BoolMatrix>,
|
||||
) -> bool {
|
||||
if let Some(ref prev_board) = previous_board_state {
|
||||
// `*prev_board == current_board` demonstrates `PartialEq` for `Matrix`.
|
||||
return *prev_board == *current_board;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Produces an order-sensitive hash of `board` for cycle detection.
///
/// Each cell is mapped to 0/1, combined with a same-shaped matrix of
/// primes via `Matrix` multiplication, and the resulting entries are
/// summed. `primes` must contain exactly `rows * cols` values.
/// Identical boards always hash equal; distinct boards may collide,
/// which is acceptable for the repeat detector.
pub fn hash_board(board: &BoolMatrix, primes: Vec<i32>) -> usize {
    // Map each boolean cell to a 0/1 integer.
    let board_ints_vec = board
        .data()
        .iter()
        .map(|&cell| if cell { 1 } else { 0 })
        .collect::<Vec<i32>>();

    let ints_board = Matrix::from_vec(board_ints_vec, board.rows(), board.cols());

    // Same-shaped matrix pairing each cell with its own prime weight.
    let primes_board = Matrix::from_vec(primes, ints_board.rows(), ints_board.cols());

    // NOTE(review): assumes `*` on `Matrix` is element-wise — verify
    // against the Matrix implementation (matrix multiplication would
    // also type-check for this square board).
    let result = ints_board * primes_board;
    let result: i32 = result.data().iter().sum();
    result as usize
}
|
||||
|
||||
/// Detects a period-2 oscillation in the rolling window of board hashes.
///
/// With at least four hashes recorded, the board is considered to be
/// alternating when `hashes[0] == hashes[2]` AND `hashes[1] == hashes[3]`.
/// The oldest hash is popped on every call (once the window has four
/// entries) so the window stays bounded.
///
/// BUGFIX: the second half of the condition previously re-tested
/// `hashes[0] == hashes[2]`, so one coincidental match two frames apart
/// was misreported as a repeating state; it now checks
/// `hashes[1] == hashes[3]` as the comment always intended.
pub fn detect_repeating_state(board_hashes: &mut Vec<usize>) -> bool {
    // Need two full periods (four samples) before we can compare.
    if board_hashes.len() < 4 {
        return false;
    }
    let result =
        board_hashes[0] == board_hashes[2] && board_hashes[1] == board_hashes[3];
    // Slide the window: drop the oldest hash.
    board_hashes.remove(0);
    result
}
|
||||
|
||||
pub fn add_simulated_activity(current_board: &mut BoolMatrix, board_size: usize) {
|
||||
for _ in 0..20 {
|
||||
generate_glider(current_board, board_size);
|
||||
}
|
||||
|
||||
// Generate a Pulsar pattern
|
||||
for _ in 0..10 {
|
||||
generate_pulsar(current_board, board_size);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the first `n` prime numbers (empty for `n <= 0`).
///
/// Uses trial division: a candidate is prime when no divisor `d` with
/// `d * d <= candidate` divides it. The integer bound is exact, so no
/// floating-point square root is needed.
pub fn generate_primes(n: i32) -> Vec<i32> {
    let mut primes = Vec::new();
    let mut candidate = 2; // 2 is the first prime.
    while (primes.len() as i32) < n {
        let mut is_prime = true;
        let mut divisor = 2;
        while divisor * divisor <= candidate {
            if candidate % divisor == 0 {
                is_prime = false;
                break;
            }
            divisor += 1;
        }
        if is_prime {
            primes.push(candidate);
        }
        candidate += 1;
    }
    primes
}
|
||||
|
||||
// --- Tests from previous example (can be kept or adapted) ---
#[cfg(test)]
mod tests {
    use super::*;
    use rustframe::matrix::{BoolMatrix, BoolOps}; // Assuming BoolOps is available for .count()

    // A blinker (line of three live cells) is a period-2 oscillator:
    // it rotates 90 degrees each generation and returns to its original
    // configuration after two steps.
    #[test]
    fn test_blinker_oscillator() {
        let initial_data = vec![false, true, false, false, true, false, false, true, false];
        let game1 = BoolMatrix::from_vec(initial_data.clone(), 3, 3);
        // Expected state after one step: the rotated blinker.
        let expected_frame2_data = vec![false, false, false, true, true, true, false, false, false];
        let expected_game2 = BoolMatrix::from_vec(expected_frame2_data, 3, 3);
        let game2 = game_of_life_next_frame(&game1);
        assert_eq!(
            game2.data(),
            expected_game2.data(),
            "Frame 1 to Frame 2 failed for blinker"
        );
        // A second step must restore the initial configuration.
        let expected_game3 = BoolMatrix::from_vec(initial_data, 3, 3);
        let game3 = game_of_life_next_frame(&game2);
        assert_eq!(
            game3.data(),
            expected_game3.data(),
            "Frame 2 to Frame 3 failed for blinker"
        );
    }

    // With no live cells there are no births, so the board stays empty.
    #[test]
    fn test_empty_board_remains_empty() {
        let board_3x3_all_false = BoolMatrix::from_vec(vec![false; 9], 3, 3);
        let next_frame = game_of_life_next_frame(&board_3x3_all_false);
        assert_eq!(
            next_frame.count(),
            0,
            "All-false board should result in all-false"
        );
    }

    // The 0x0 board is an explicit special case in game_of_life_next_frame.
    #[test]
    fn test_zero_size_board() {
        let board_0x0 = BoolMatrix::from_vec(vec![], 0, 0);
        let next_frame = game_of_life_next_frame(&board_0x0);
        assert_eq!(next_frame.rows(), 0);
        assert_eq!(next_frame.cols(), 0);
        assert!(
            next_frame.data().is_empty(),
            "0x0 board should result in 0x0 board"
        );
    }

    // A 2x2 block is a still life: every live cell has exactly three
    // live neighbours and no dead cell reaches three.
    #[test]
    fn test_still_life_block() {
        let block_data = vec![
            true, true, false, false, true, true, false, false, false, false, false, false, false,
            false, false, false,
        ];
        let game_block = BoolMatrix::from_vec(block_data.clone(), 4, 4);
        let next_frame_block = game_of_life_next_frame(&game_block);
        assert_eq!(
            next_frame_block.data(),
            game_block.data(),
            "Block still life should remain unchanged"
        );
    }
}
|
||||
66
examples/inferential_stats.rs
Normal file
66
examples/inferential_stats.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use rustframe::compute::stats::{anova, chi2_test, t_test};
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Demonstrates simple inferential statistics tests.
|
||||
fn main() {
|
||||
t_test_demo();
|
||||
println!("\n-----\n");
|
||||
chi2_demo();
|
||||
println!("\n-----\n");
|
||||
anova_demo();
|
||||
}
|
||||
|
||||
fn t_test_demo() {
|
||||
println!("Two-sample t-test\n-----------------");
|
||||
let a = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
let b = Matrix::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
|
||||
let (t, p) = t_test(&a, &b);
|
||||
println!("t statistic: {:.2}, p-value: {:.4}", t, p);
|
||||
}
|
||||
|
||||
fn chi2_demo() {
|
||||
println!("Chi-square test\n---------------");
|
||||
let observed = Matrix::from_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
|
||||
let (chi2, p) = chi2_test(&observed);
|
||||
println!("chi^2: {:.2}, p-value: {:.4}", chi2, p);
|
||||
}
|
||||
|
||||
fn anova_demo() {
|
||||
println!("One-way ANOVA\n-------------");
|
||||
let g1 = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3);
|
||||
let g2 = Matrix::from_vec(vec![2.0, 3.0, 4.0], 1, 3);
|
||||
let g3 = Matrix::from_vec(vec![3.0, 4.0, 5.0], 1, 3);
|
||||
let (f, p) = anova(vec![&g1, &g2, &g3]);
|
||||
println!("F statistic: {:.2}, p-value: {:.4}", f, p);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // For these two shifted samples the t statistic is exactly -5
    // (mean difference -5 over a pooled standard error of 1).
    #[test]
    fn test_t_test_demo() {
        let a = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
        let b = Matrix::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
        let (t, _p) = t_test(&a, &b);
        assert!((t + 5.0).abs() < 1e-5);
    }

    // The chi-square statistic is positive and its p-value lies in (0, 1).
    #[test]
    fn test_chi2_demo() {
        let observed = Matrix::from_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
        let (chi2, p) = chi2_test(&observed);
        assert!(chi2 > 0.0);
        assert!(p > 0.0 && p < 1.0);
    }

    // ANOVA across three overlapping groups: F positive, p in (0, 1).
    #[test]
    fn test_anova_demo() {
        let g1 = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3);
        let g2 = Matrix::from_vec(vec![2.0, 3.0, 4.0], 1, 3);
        let g3 = Matrix::from_vec(vec![3.0, 4.0, 5.0], 1, 3);
        let (f, p) = anova(vec![&g1, &g2, &g3]);
        assert!(f > 0.0);
        assert!(p > 0.0 && p < 1.0);
    }
}
|
||||
65
examples/k_means.rs
Normal file
65
examples/k_means.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use rustframe::compute::models::k_means::KMeans;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Two quick K-Means clustering demos.
|
||||
///
|
||||
/// Example 1 groups store locations on a city map.
|
||||
/// Example 2 segments customers by annual spending habits.
|
||||
fn main() {
|
||||
city_store_example();
|
||||
println!("\n-----\n");
|
||||
customer_spend_example();
|
||||
}
|
||||
|
||||
fn city_store_example() {
|
||||
println!("Example 1: store locations");
|
||||
|
||||
// (x, y) coordinates of stores around a city
|
||||
let raw = vec![
|
||||
1.0, 2.0, 1.5, 1.8, 5.0, 8.0, 8.0, 8.0, 1.0, 0.6, 9.0, 11.0, 8.0, 2.0, 10.0, 2.0, 9.0, 3.0,
|
||||
];
|
||||
let x = Matrix::from_rows_vec(raw, 9, 2);
|
||||
|
||||
// Group stores into two areas
|
||||
let (model, labels) = KMeans::fit(&x, 2, 100, 1e-4);
|
||||
|
||||
println!("Centres: {:?}", model.centroids.data());
|
||||
println!("Labels: {:?}", labels);
|
||||
|
||||
let new_points = Matrix::from_rows_vec(vec![0.0, 0.0, 8.0, 3.0], 2, 2);
|
||||
let pred = model.predict(&new_points);
|
||||
println!("New store assignments: {:?}", pred);
|
||||
}
|
||||
|
||||
fn customer_spend_example() {
|
||||
println!("Example 2: customer spending");
|
||||
|
||||
// (grocery spend, electronics spend) in dollars
|
||||
let raw = vec![
|
||||
200.0, 150.0, 220.0, 170.0, 250.0, 160.0, 800.0, 750.0, 820.0, 760.0, 790.0, 770.0,
|
||||
];
|
||||
let x = Matrix::from_rows_vec(raw, 6, 2);
|
||||
|
||||
let (model, labels) = KMeans::fit(&x, 2, 100, 1e-4);
|
||||
|
||||
println!("Centres: {:?}", model.centroids.data());
|
||||
println!("Labels: {:?}", labels);
|
||||
|
||||
let new_customers = Matrix::from_rows_vec(vec![230.0, 155.0, 810.0, 760.0], 2, 2);
|
||||
let pred = model.predict(&new_customers);
|
||||
println!("Cluster of new customers: {:?}", pred);
|
||||
}
|
||||
|
||||
// Fitting k = 2 on the nine store coordinates must label every sample
// and produce one centroid row per cluster; predict must label every
// query point.
#[test]
fn k_means_store_locations() {
    // Same nine (x, y) store coordinates as the demo above.
    let raw = vec![
        1.0, 2.0, 1.5, 1.8, 5.0, 8.0, 8.0, 8.0, 1.0, 0.6, 9.0, 11.0, 8.0, 2.0, 10.0, 2.0, 9.0, 3.0,
    ];
    let x = Matrix::from_rows_vec(raw, 9, 2);
    let (model, labels) = KMeans::fit(&x, 2, 100, 1e-4);
    // One label per training sample.
    assert_eq!(labels.len(), 9);
    // One centroid per cluster.
    assert_eq!(model.centroids.rows(), 2);
    let new_points = Matrix::from_rows_vec(vec![0.0, 0.0, 8.0, 3.0], 2, 2);
    let pred = model.predict(&new_points);
    assert_eq!(pred.len(), 2);
}
|
||||
118
examples/linear_regression.rs
Normal file
118
examples/linear_regression.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use rustframe::compute::models::linreg::LinReg;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Two quick linear regression demonstrations.
|
||||
///
|
||||
/// Example 1 fits a model to predict house price from floor area.
|
||||
/// Example 2 adds number of bedrooms as a second feature.
|
||||
fn main() {
|
||||
example_one_feature();
|
||||
println!("\n-----\n");
|
||||
example_two_features();
|
||||
}
|
||||
|
||||
/// Price ~ floor area
|
||||
fn example_one_feature() {
|
||||
println!("Example 1: predict price from floor area only");
|
||||
|
||||
// Square meters of floor area for a few houses
|
||||
let sizes = vec![50.0, 60.0, 70.0, 80.0, 90.0, 100.0];
|
||||
// Thousands of dollars in sale price
|
||||
let prices = vec![150.0, 180.0, 210.0, 240.0, 270.0, 300.0];
|
||||
|
||||
// Each row is a sample with one feature
|
||||
let x = Matrix::from_vec(sizes.clone(), sizes.len(), 1);
|
||||
let y = Matrix::from_vec(prices.clone(), prices.len(), 1);
|
||||
|
||||
// Train with a small learning rate
|
||||
let mut model = LinReg::new(1);
|
||||
model.fit(&x, &y, 0.0005, 20000);
|
||||
|
||||
let preds = model.predict(&x);
|
||||
println!("Size (m^2) -> predicted price (k) vs actual");
|
||||
for i in 0..x.rows() {
|
||||
println!(
|
||||
"{:>3} -> {:>6.1} | {:>6.1}",
|
||||
sizes[i],
|
||||
preds[(i, 0)],
|
||||
prices[i]
|
||||
);
|
||||
}
|
||||
|
||||
let new_house = Matrix::from_vec(vec![120.0], 1, 1);
|
||||
let pred = model.predict(&new_house);
|
||||
println!("Predicted price for 120 m^2: {:.1}k", pred[(0, 0)]);
|
||||
}
|
||||
|
||||
/// Price ~ floor area + bedrooms
|
||||
fn example_two_features() {
|
||||
println!("Example 2: price from area and bedrooms");
|
||||
|
||||
// (size m^2, bedrooms) for each house
|
||||
let raw_x = vec![
|
||||
50.0, 2.0, 70.0, 2.0, 90.0, 3.0, 110.0, 3.0, 130.0, 4.0, 150.0, 4.0,
|
||||
];
|
||||
let prices = vec![160.0, 195.0, 250.0, 285.0, 320.0, 350.0];
|
||||
|
||||
let x = Matrix::from_rows_vec(raw_x, 6, 2);
|
||||
let y = Matrix::from_vec(prices.clone(), prices.len(), 1);
|
||||
|
||||
let mut model = LinReg::new(2);
|
||||
model.fit(&x, &y, 0.0001, 50000);
|
||||
|
||||
let preds = model.predict(&x);
|
||||
println!("size, beds -> predicted | actual (k)");
|
||||
for i in 0..x.rows() {
|
||||
let size = x[(i, 0)];
|
||||
let beds = x[(i, 1)];
|
||||
println!(
|
||||
"{:>3} m^2, {:>1} -> {:>6.1} | {:>6.1}",
|
||||
size,
|
||||
beds,
|
||||
preds[(i, 0)],
|
||||
prices[i]
|
||||
);
|
||||
}
|
||||
|
||||
let new_home = Matrix::from_rows_vec(vec![120.0, 3.0], 1, 2);
|
||||
let pred = model.predict(&new_home);
|
||||
println!(
|
||||
"Predicted price for 120 m^2 with 3 bedrooms: {:.1}k",
|
||||
pred[(0, 0)]
|
||||
);
|
||||
}
|
||||
|
||||
// With the feature rescaled to ~1 the model converges quickly; every
// training prediction should land within 1k of the true price.
#[test]
fn test_linear_regression_one_feature() {
    let sizes = vec![50.0, 60.0, 70.0, 80.0, 90.0, 100.0];
    let prices = vec![150.0, 180.0, 210.0, 240.0, 270.0, 300.0];
    // Feature scaling: divide sizes by 100 so gradient descent is stable.
    let scaled: Vec<f64> = sizes.iter().map(|s| s / 100.0).collect();
    let x = Matrix::from_vec(scaled, sizes.len(), 1);
    let y = Matrix::from_vec(prices.clone(), prices.len(), 1);
    let mut model = LinReg::new(1);
    model.fit(&x, &y, 0.1, 2000);
    let preds = model.predict(&x);
    for i in 0..y.rows() {
        assert!((preds[(i, 0)] - prices[i]).abs() < 1.0);
    }
}
|
||||
|
||||
// Two-feature variant: sizes are rescaled, bedroom counts kept as-is;
// all training predictions must land within 1k of the true prices.
#[test]
fn test_linear_regression_two_features() {
    let raw_x = vec![
        50.0, 2.0, 70.0, 2.0, 90.0, 3.0, 110.0, 3.0, 130.0, 4.0, 150.0, 4.0,
    ];
    let prices = vec![170.0, 210.0, 270.0, 310.0, 370.0, 410.0];
    // Scale only the first feature (size) of each (size, bedrooms) pair.
    let scaled_x: Vec<f64> = raw_x
        .chunks(2)
        .flat_map(|pair| vec![pair[0] / 100.0, pair[1]])
        .collect();
    let x = Matrix::from_rows_vec(scaled_x, 6, 2);
    let y = Matrix::from_vec(prices.clone(), prices.len(), 1);
    let mut model = LinReg::new(2);
    model.fit(&x, &y, 0.01, 50000);
    let preds = model.predict(&x);
    for i in 0..y.rows() {
        assert!((preds[(i, 0)] - prices[i]).abs() < 1.0);
    }
}
|
||||
101
examples/logistic_regression.rs
Normal file
101
examples/logistic_regression.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use rustframe::compute::models::logreg::LogReg;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Two binary classification demos using logistic regression.
|
||||
///
|
||||
/// Example 1 predicts exam success from hours studied.
|
||||
/// Example 2 predicts whether an online shopper will make a purchase.
|
||||
fn main() {
|
||||
student_passing_example();
|
||||
println!("\n-----\n");
|
||||
purchase_prediction_example();
|
||||
}
|
||||
|
||||
fn student_passing_example() {
|
||||
println!("Example 1: exam pass prediction");
|
||||
|
||||
// Hours studied for each student
|
||||
let hours = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||
// Label: 0 denotes failure and 1 denotes success
|
||||
let passed = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0];
|
||||
|
||||
let x = Matrix::from_vec(hours.clone(), hours.len(), 1);
|
||||
let y = Matrix::from_vec(passed.clone(), passed.len(), 1);
|
||||
|
||||
let mut model = LogReg::new(1);
|
||||
model.fit(&x, &y, 0.1, 10000);
|
||||
|
||||
let preds = model.predict(&x);
|
||||
println!("Hours -> pred | actual");
|
||||
for i in 0..x.rows() {
|
||||
println!(
|
||||
"{:>2} -> {} | {}",
|
||||
hours[i] as i32,
|
||||
preds[(i, 0)] as i32,
|
||||
passed[i] as i32
|
||||
);
|
||||
}
|
||||
|
||||
// Probability estimate for a new student
|
||||
let new_student = Matrix::from_vec(vec![5.5], 1, 1);
|
||||
let p = model.predict_proba(&new_student);
|
||||
println!("Probability of passing with 5.5h study: {:.2}", p[(0, 0)]);
|
||||
}
|
||||
|
||||
fn purchase_prediction_example() {
|
||||
println!("Example 2: purchase likelihood");
|
||||
|
||||
// minutes on site, pages viewed -> made a purchase?
|
||||
let raw_x = vec![1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 5.0, 5.0, 3.5, 2.0, 6.0, 6.0];
|
||||
let bought = vec![0.0, 0.0, 0.0, 1.0, 0.0, 1.0];
|
||||
|
||||
let x = Matrix::from_rows_vec(raw_x, 6, 2);
|
||||
let y = Matrix::from_vec(bought.clone(), bought.len(), 1);
|
||||
|
||||
let mut model = LogReg::new(2);
|
||||
model.fit(&x, &y, 0.05, 20000);
|
||||
|
||||
let preds = model.predict(&x);
|
||||
println!("time, pages -> pred | actual");
|
||||
for i in 0..x.rows() {
|
||||
println!(
|
||||
"{:>4}m, {:>2} -> {} | {}",
|
||||
x[(i, 0)],
|
||||
x[(i, 1)] as i32,
|
||||
preds[(i, 0)] as i32,
|
||||
bought[i] as i32
|
||||
);
|
||||
}
|
||||
|
||||
let new_visit = Matrix::from_rows_vec(vec![4.0, 4.0], 1, 2);
|
||||
let p = model.predict_proba(&new_visit);
|
||||
println!("Prob of purchase for 4min/4pages: {:.2}", p[(0, 0)]);
|
||||
}
|
||||
|
||||
// The classes are linearly separable (everyone below 5 hours failed),
// so the fitted model should classify every training sample correctly.
#[test]
fn test_student_passing_example() {
    let hours = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let passed = vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0];
    let x = Matrix::from_vec(hours.clone(), hours.len(), 1);
    let y = Matrix::from_vec(passed.clone(), passed.len(), 1);
    let mut model = LogReg::new(1);
    model.fit(&x, &y, 0.1, 10000);
    let preds = model.predict(&x);
    // Every prediction must match its training label exactly.
    for i in 0..y.rows() {
        assert_eq!(preds[(i, 0)], passed[i]);
    }
}
|
||||
|
||||
// Two-feature variant: the fitted model is expected to reproduce all
// six training labels on this small separable data set.
#[test]
fn test_purchase_prediction_example() {
    let raw_x = vec![1.0, 2.0, 3.0, 1.0, 2.0, 4.0, 5.0, 5.0, 3.5, 2.0, 6.0, 6.0];
    let bought = vec![0.0, 0.0, 0.0, 1.0, 0.0, 1.0];
    let x = Matrix::from_rows_vec(raw_x, 6, 2);
    let y = Matrix::from_vec(bought.clone(), bought.len(), 1);
    let mut model = LogReg::new(2);
    model.fit(&x, &y, 0.05, 20000);
    let preds = model.predict(&x);
    // Every prediction must match its training label exactly.
    for i in 0..y.rows() {
        assert_eq!(preds[(i, 0)], bought[i]);
    }
}
|
||||
60
examples/pca.rs
Normal file
60
examples/pca.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use rustframe::compute::models::pca::PCA;
|
||||
use rustframe::matrix::Matrix;
|
||||
|
||||
/// Two dimensionality reduction examples using PCA.
|
||||
///
|
||||
/// Example 1 reduces 3D sensor readings to two components.
|
||||
/// Example 2 compresses a small four-feature dataset.
|
||||
fn main() {
|
||||
sensor_demo();
|
||||
println!("\n-----\n");
|
||||
finance_demo();
|
||||
}
|
||||
|
||||
fn sensor_demo() {
|
||||
println!("Example 1: 3D sensor data");
|
||||
|
||||
// Ten 3D observations from an accelerometer
|
||||
let raw = vec![
|
||||
2.5, 2.4, 0.5, 0.5, 0.7, 1.5, 2.2, 2.9, 0.7, 1.9, 2.2, 1.0, 3.1, 3.0, 0.6, 2.3, 2.7, 0.9,
|
||||
2.0, 1.6, 1.1, 1.0, 1.1, 1.9, 1.5, 1.6, 2.2, 1.1, 0.9, 2.1,
|
||||
];
|
||||
let x = Matrix::from_rows_vec(raw, 10, 3);
|
||||
|
||||
let pca = PCA::fit(&x, 2, 0);
|
||||
let reduced = pca.transform(&x);
|
||||
|
||||
println!("Components: {:?}", pca.components.data());
|
||||
println!("First row -> {:.2?}", [reduced[(0, 0)], reduced[(0, 1)]]);
|
||||
}
|
||||
|
||||
fn finance_demo() {
|
||||
println!("Example 2: 4D finance data");
|
||||
|
||||
// Four daily percentage returns of different stocks
|
||||
let raw = vec![
|
||||
0.2, 0.1, -0.1, 0.0, 0.3, 0.2, -0.2, 0.1, 0.1, 0.0, -0.1, -0.1, 0.4, 0.3, -0.3, 0.2, 0.0,
|
||||
-0.1, 0.1, -0.1,
|
||||
];
|
||||
let x = Matrix::from_rows_vec(raw, 5, 4);
|
||||
|
||||
// Keep two principal components
|
||||
let pca = PCA::fit(&x, 2, 0);
|
||||
let reduced = pca.transform(&x);
|
||||
|
||||
println!("Reduced shape: {:?}", reduced.shape());
|
||||
println!("First row -> {:.2?}", [reduced[(0, 0)], reduced[(0, 1)]]);
|
||||
}
|
||||
|
||||
// Reducing 10x3 data to two principal components must preserve the row
// count and yield exactly two output columns.
#[test]
fn test_sensor_demo() {
    let raw = vec![
        2.5, 2.4, 0.5, 0.5, 0.7, 1.5, 2.2, 2.9, 0.7, 1.9, 2.2, 1.0, 3.1, 3.0, 0.6, 2.3, 2.7, 0.9,
        2.0, 1.6, 1.1, 1.0, 1.1, 1.9, 1.5, 1.6, 2.2, 1.1, 0.9, 2.1,
    ];
    let x = Matrix::from_rows_vec(raw, 10, 3);
    let pca = PCA::fit(&x, 2, 0);
    let reduced = pca.transform(&x);
    assert_eq!(reduced.rows(), 10);
    assert_eq!(reduced.cols(), 2);
}
|
||||
67
examples/random_demo.rs
Normal file
67
examples/random_demo.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use rustframe::random::{crypto_rng, rng, Rng, SliceRandom};
|
||||
|
||||
/// Demonstrates basic usage of the random number generators.
|
||||
///
|
||||
/// It showcases uniform ranges, booleans, normal distribution,
|
||||
/// shuffling and the cryptographically secure generator.
|
||||
fn main() {
|
||||
basic_usage();
|
||||
println!("\n-----\n");
|
||||
normal_demo();
|
||||
println!("\n-----\n");
|
||||
shuffle_demo();
|
||||
}
|
||||
|
||||
fn basic_usage() {
|
||||
println!("Basic PRNG usage\n----------------");
|
||||
let mut prng = rng();
|
||||
println!("random u64 : {}", prng.next_u64());
|
||||
println!("range [10,20): {}", prng.random_range(10..20));
|
||||
println!("bool : {}", prng.gen_bool());
|
||||
}
|
||||
|
||||
fn normal_demo() {
|
||||
println!("Normal distribution\n-------------------");
|
||||
let mut prng = rng();
|
||||
for _ in 0..3 {
|
||||
let v = prng.normal(0.0, 1.0);
|
||||
println!("sample: {:.3}", v);
|
||||
}
|
||||
}
|
||||
|
||||
fn shuffle_demo() {
|
||||
println!("Slice shuffling\n----------------");
|
||||
let mut prng = rng();
|
||||
let mut data = [1, 2, 3, 4, 5];
|
||||
data.shuffle(&mut prng);
|
||||
println!("shuffled: {:?}", data);
|
||||
|
||||
let mut secure = crypto_rng();
|
||||
let byte = secure.random_range(0..256usize);
|
||||
println!("crypto byte: {}", byte);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use rustframe::random::{CryptoRng, Prng};

    // random_range must honour its half-open bounds on the seeded PRNG.
    #[test]
    fn test_basic_usage_range_bounds() {
        let mut rng = Prng::new(1);
        for _ in 0..50 {
            let v = rng.random_range(5..10);
            assert!(v >= 5 && v < 10);
        }
    }

    // Same bounds check against the cryptographically secure generator.
    #[test]
    fn test_crypto_byte_bounds() {
        let mut rng = CryptoRng::new();
        for _ in 0..50 {
            let v = rng.random_range(0..256usize);
            assert!(v < 256);
        }
    }
}
|
||||
|
||||
57
examples/random_stats.rs
Normal file
57
examples/random_stats.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use rustframe::random::{crypto_rng, rng, Rng};
|
||||
|
||||
/// Demonstrates simple statistical checks on random number generators.
|
||||
fn main() {
|
||||
chi_square_demo();
|
||||
println!("\n-----\n");
|
||||
monobit_demo();
|
||||
}
|
||||
|
||||
fn chi_square_demo() {
|
||||
println!("Chi-square test on PRNG");
|
||||
let mut rng = rng();
|
||||
let mut counts = [0usize; 10];
|
||||
let samples = 10000;
|
||||
for _ in 0..samples {
|
||||
let v = rng.random_range(0..10usize);
|
||||
counts[v] += 1;
|
||||
}
|
||||
let expected = samples as f64 / 10.0;
|
||||
let chi2: f64 = counts
|
||||
.iter()
|
||||
.map(|&c| {
|
||||
let diff = c as f64 - expected;
|
||||
diff * diff / expected
|
||||
})
|
||||
.sum();
|
||||
println!("counts: {:?}", counts);
|
||||
println!("chi-square: {:.3}", chi2);
|
||||
}
|
||||
|
||||
fn monobit_demo() {
|
||||
println!("Monobit test on crypto RNG");
|
||||
let mut rng = crypto_rng();
|
||||
let mut ones = 0usize;
|
||||
let samples = 1000;
|
||||
for _ in 0..samples {
|
||||
ones += rng.next_u64().count_ones() as usize;
|
||||
}
|
||||
let ratio = ones as f64 / (samples as f64 * 64.0);
|
||||
println!("ones ratio: {:.4}", ratio);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the chi-square demo must run to completion without
    // panicking (its printed statistic is random, so no value is pinned).
    #[test]
    fn test_chi_square_demo_runs() {
        chi_square_demo();
    }

    // Smoke test: the monobit demo must run without panicking.
    #[test]
    fn test_monobit_demo_runs() {
        monobit_demo();
    }
}
|
||||
|
||||
93
examples/stats_overview.rs
Normal file
93
examples/stats_overview.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use rustframe::compute::stats::{
|
||||
chi2_test, covariance, covariance_matrix, mean, median, pearson, percentile, stddev, t_test,
|
||||
};
|
||||
use rustframe::matrix::{Axis, Matrix};
|
||||
|
||||
/// Demonstrates some of the statistics utilities in Rustframe.
|
||||
///
|
||||
/// The example is split into three parts:
|
||||
/// - Basic descriptive statistics on a small data set
|
||||
/// - Covariance and correlation calculations
|
||||
/// - Simple inferential tests (t-test and chi-square)
|
||||
fn main() {
|
||||
descriptive_demo();
|
||||
println!("\n-----\n");
|
||||
correlation_demo();
|
||||
println!("\n-----\n");
|
||||
inferential_demo();
|
||||
}
|
||||
|
||||
fn descriptive_demo() {
|
||||
println!("Descriptive statistics\n----------------------");
|
||||
let data = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
println!("mean : {:.2}", mean(&data));
|
||||
println!("std dev : {:.2}", stddev(&data));
|
||||
println!("median : {:.2}", median(&data));
|
||||
println!("25th percentile: {:.2}", percentile(&data, 25.0));
|
||||
}
|
||||
|
||||
fn correlation_demo() {
|
||||
println!("Covariance and Correlation\n--------------------------");
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
let y = Matrix::from_vec(vec![1.0, 2.0, 3.0, 5.0], 2, 2);
|
||||
let cov = covariance(&x, &y);
|
||||
let cov_mat = covariance_matrix(&x, Axis::Col);
|
||||
let corr = pearson(&x, &y);
|
||||
println!("covariance : {:.2}", cov);
|
||||
println!("cov matrix : {:?}", cov_mat.data());
|
||||
println!("pearson r : {:.2}", corr);
|
||||
}
|
||||
|
||||
fn inferential_demo() {
|
||||
println!("Inferential statistics\n----------------------");
|
||||
let s1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
let s2 = Matrix::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
|
||||
let (t_stat, t_p) = t_test(&s1, &s2);
|
||||
println!("t statistic : {:.2}, p-value: {:.4}", t_stat, t_p);
|
||||
|
||||
let observed = Matrix::from_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
|
||||
let (chi2, chi_p) = chi2_test(&observed);
|
||||
println!("chi^2 : {:.2}, p-value: {:.4}", chi2, chi_p);
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-8;

    // Pins the descriptive statistics for the 1..=5 data set: the mean and
    // median are 3, and sqrt(2) is the population (1/n) standard deviation.
    #[test]
    fn test_descriptive_demo() {
        let data = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
        assert!((mean(&data) - 3.0).abs() < EPS);
        assert!((stddev(&data) - 1.4142135623730951).abs() < EPS);
        assert!((median(&data) - 3.0).abs() < EPS);
        assert!((percentile(&data, 25.0) - 2.0).abs() < EPS);
    }

    // Checks covariance, the per-column covariance matrix diagonal, and the
    // Pearson correlation for a slightly perturbed pair of 2x2 matrices.
    #[test]
    fn test_correlation_demo() {
        let x = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let y = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 5.0], 2, 2);
        let cov = covariance(&x, &y);
        assert!((cov - 1.625).abs() < EPS);
        let cov_mat = covariance_matrix(&x, Axis::Col);
        assert!((cov_mat.get(0, 0) - 2.0).abs() < EPS);
        assert!((cov_mat.get(1, 1) - 2.0).abs() < EPS);
        let corr = pearson(&x, &y);
        // Looser tolerance: the reference value is quoted to fewer digits.
        assert!((corr - 0.9827076298239908).abs() < 1e-6);
    }

    // The two samples differ by a constant shift of 5, so the t statistic is
    // expected to be approximately -5; p-values only need to be valid
    // probabilities here.
    #[test]
    fn test_inferential_demo() {
        let s1 = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
        let s2 = Matrix::from_rows_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
        let (t_stat, p_value) = t_test(&s1, &s2);
        assert!((t_stat + 5.0).abs() < 1e-5);
        assert!(p_value > 0.0 && p_value < 1.0);

        let observed = Matrix::from_rows_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
        let (chi2, p) = chi2_test(&observed);
        assert!(chi2 > 0.0);
        assert!(p > 0.0 && p < 1.0);
    }
}
|
||||
16
src/compute/mod.rs
Normal file
16
src/compute/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Algorithms and statistical utilities built on top of the core matrices.
//!
//! This module groups together machine-learning models and statistical helper
//! functions. For quick access to basic statistics see [`stats`](crate::compute::stats), while
//! [`models`](crate::compute::models) contains small learning algorithms.
//!
//! ```
//! use rustframe::compute::stats;
//! use rustframe::matrix::Matrix;
//!
//! let m = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1);
//! assert_eq!(stats::mean(&m), 2.0);
//! ```

// Small machine-learning models (linear/logistic regression, NB, dense NN, ...).
pub mod models;

// Descriptive and inferential statistics helpers.
pub mod stats;
|
||||
148
src/compute/models/activations.rs
Normal file
148
src/compute/models/activations.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
//! Common activation functions used in neural networks.
|
||||
//!
|
||||
//! Functions operate element-wise on [`Matrix`] values.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::activations::sigmoid;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![0.0], 1, 1);
|
||||
//! let y = sigmoid(&x);
|
||||
//! assert!((y.get(0,0) - 0.5).abs() < 1e-6);
|
||||
//! ```
|
||||
use crate::matrix::{Matrix, SeriesOps};
|
||||
|
||||
pub fn sigmoid(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
x.map(|v| 1.0 / (1.0 + (-v).exp()))
|
||||
}
|
||||
|
||||
pub fn dsigmoid(y: &Matrix<f64>) -> Matrix<f64> {
|
||||
// derivative w.r.t. pre-activation; takes y = sigmoid(x)
|
||||
y.map(|v| v * (1.0 - v))
|
||||
}
|
||||
|
||||
pub fn relu(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
x.map(|v| if v > 0.0 { v } else { 0.0 })
|
||||
}
|
||||
|
||||
pub fn drelu(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
x.map(|v| if v > 0.0 { 1.0 } else { 0.0 })
|
||||
}
|
||||
|
||||
pub fn leaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
x.map(|v| if v > 0.0 { v } else { 0.01 * v })
|
||||
}
|
||||
|
||||
pub fn dleaky_relu(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
x.map(|v| if v > 0.0 { 1.0 } else { 0.01 })
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to round all elements in a matrix to n decimal places,
    // so results can be compared with `assert_eq!` despite float noise.
    fn _round_matrix(mat: &Matrix<f64>, decimals: u32) -> Matrix<f64> {
        let factor = 10f64.powi(decimals as i32);
        let rounded: Vec<f64> = mat
            .to_vec()
            .iter()
            .map(|v| (v * factor).round() / factor)
            .collect();
        Matrix::from_vec(rounded, mat.rows(), mat.cols())
    }

    // sigmoid(-1), sigmoid(0), sigmoid(1) against reference values.
    #[test]
    fn test_sigmoid() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
        let result = sigmoid(&x);
        assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
    }

    // Saturation: very large magnitudes should clamp to 0 and 1 without NaN.
    #[test]
    fn test_sigmoid_edge_case() {
        let x = Matrix::from_vec(vec![-1000.0, 0.0, 1000.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.5, 1.0], 3, 1);
        let result = sigmoid(&x);

        for (r, e) in result.data().iter().zip(expected.data().iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    // ReLU zeroes negatives and passes non-negatives through.
    #[test]
    fn test_relu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        assert_eq!(relu(&x), expected);
    }

    // Tiny negative and huge positive magnitudes behave the same way.
    #[test]
    fn test_relu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1e10], 3, 1);
        assert_eq!(relu(&x), expected);
    }

    // dsigmoid takes sigmoid *outputs* y and returns y * (1 - y).
    #[test]
    fn test_dsigmoid() {
        let y = Matrix::from_vec(vec![0.26894142, 0.5, 0.73105858], 3, 1);
        let expected = Matrix::from_vec(vec![0.19661193, 0.25, 0.19661193], 3, 1);
        let result = dsigmoid(&y);
        assert_eq!(_round_matrix(&result, 6), _round_matrix(&expected, 6));
    }

    // Derivative vanishes at the saturation endpoints y = 0 and y = 1.
    #[test]
    fn test_dsigmoid_edge_case() {
        let y = Matrix::from_vec(vec![0.0, 0.5, 1.0], 3, 1); // Assume these are outputs from sigmoid(x)
        let expected = Matrix::from_vec(vec![0.0, 0.25, 0.0], 3, 1);
        let result = dsigmoid(&y);

        for (r, e) in result.data().iter().zip(expected.data().iter()) {
            assert!((r - e).abs() < 1e-6);
        }
    }

    // ReLU derivative: step function, 0 at the origin.
    #[test]
    fn test_drelu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        assert_eq!(drelu(&x), expected);
    }

    // Step function at extreme magnitudes.
    #[test]
    fn test_drelu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        assert_eq!(drelu(&x), expected);
    }

    // Leaky ReLU scales negatives by 0.01 rather than zeroing them.
    #[test]
    fn test_leaky_relu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![-0.01, 0.0, 1.0], 3, 1);
        assert_eq!(leaky_relu(&x), expected);
    }

    // -1e-10 * 0.01 = -1e-12; positives pass through unchanged.
    #[test]
    fn test_leaky_relu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![-1e-12, 0.0, 1e10], 3, 1);
        assert_eq!(leaky_relu(&x), expected);
    }

    // Leaky-ReLU derivative: 0.01 at and below zero, 1 above.
    #[test]
    fn test_dleaky_relu() {
        let x = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
        let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
        assert_eq!(dleaky_relu(&x), expected);
    }

    // Same derivative behavior at extreme magnitudes.
    #[test]
    fn test_dleaky_relu_edge_case() {
        let x = Matrix::from_vec(vec![-1e-10, 0.0, 1e10], 3, 1);
        let expected = Matrix::from_vec(vec![0.01, 0.01, 1.0], 3, 1);
        assert_eq!(dleaky_relu(&x), expected);
    }
}
|
||||
551
src/compute/models/dense_nn.rs
Normal file
551
src/compute/models/dense_nn.rs
Normal file
@@ -0,0 +1,551 @@
|
||||
//! A minimal dense neural network implementation for educational purposes.
|
||||
//!
|
||||
//! Layers operate on [`Matrix`] values and support ReLU and Sigmoid
|
||||
//! activations. This is not meant to be a performant deep‑learning framework
|
||||
//! but rather a small example of how the surrounding matrix utilities can be
|
||||
//! composed.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::dense_nn::{ActivationKind, DenseNN, DenseNNConfig, InitializerKind, LossKind};
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! // Tiny network with one input and one output neuron.
|
||||
//! let config = DenseNNConfig {
|
||||
//! input_size: 1,
|
||||
//! hidden_layers: vec![],
|
||||
//! output_size: 1,
|
||||
//! activations: vec![ActivationKind::Relu],
|
||||
//! initializer: InitializerKind::Uniform(0.5),
|
||||
//! loss: LossKind::MSE,
|
||||
//! learning_rate: 0.1,
|
||||
//! epochs: 1,
|
||||
//! };
|
||||
//! let mut nn = DenseNN::new(config);
|
||||
//! let x = Matrix::from_vec(vec![1.0, 2.0], 2, 1);
|
||||
//! let y = Matrix::from_vec(vec![2.0, 3.0], 2, 1);
|
||||
//! nn.train(&x, &y);
|
||||
//! ```
|
||||
use crate::compute::models::activations::{drelu, relu, sigmoid};
|
||||
use crate::matrix::{Matrix, SeriesOps};
|
||||
use crate::random::prelude::*;
|
||||
|
||||
/// Supported activation functions
|
||||
#[derive(Clone)]
|
||||
pub enum ActivationKind {
|
||||
Relu,
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
}
|
||||
|
||||
impl ActivationKind {
|
||||
/// Apply activation elementwise
|
||||
pub fn forward(&self, z: &Matrix<f64>) -> Matrix<f64> {
|
||||
match self {
|
||||
ActivationKind::Relu => relu(z),
|
||||
ActivationKind::Sigmoid => sigmoid(z),
|
||||
ActivationKind::Tanh => z.map(|v| v.tanh()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute elementwise derivative w.r.t. pre-activation z
|
||||
pub fn derivative(&self, z: &Matrix<f64>) -> Matrix<f64> {
|
||||
match self {
|
||||
ActivationKind::Relu => drelu(z),
|
||||
ActivationKind::Sigmoid => {
|
||||
let s = sigmoid(z);
|
||||
s.zip(&s, |si, sj| si * (1.0 - sj))
|
||||
}
|
||||
ActivationKind::Tanh => z.map(|v| 1.0 - v.tanh().powi(2)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Weight initialization schemes
|
||||
#[derive(Clone)]
|
||||
pub enum InitializerKind {
|
||||
/// Uniform(-limit .. limit)
|
||||
Uniform(f64),
|
||||
/// Xavier/Glorot uniform
|
||||
Xavier,
|
||||
/// He (Kaiming) uniform
|
||||
He,
|
||||
}
|
||||
|
||||
impl InitializerKind {
|
||||
pub fn initialize(&self, rows: usize, cols: usize) -> Matrix<f64> {
|
||||
let mut rng = rng();
|
||||
let fan_in = rows;
|
||||
let fan_out = cols;
|
||||
let limit = match self {
|
||||
InitializerKind::Uniform(l) => *l,
|
||||
InitializerKind::Xavier => (6.0 / (fan_in + fan_out) as f64).sqrt(),
|
||||
InitializerKind::He => (2.0 / fan_in as f64).sqrt(),
|
||||
};
|
||||
let data = (0..rows * cols)
|
||||
.map(|_| rng.random_range(-limit..limit))
|
||||
.collect::<Vec<_>>();
|
||||
Matrix::from_vec(data, rows, cols)
|
||||
}
|
||||
}
|
||||
|
||||
/// Supported losses
|
||||
#[derive(Clone)]
|
||||
pub enum LossKind {
|
||||
/// Mean Squared Error: L = 1/m * sum((y_hat - y)^2)
|
||||
MSE,
|
||||
/// Binary Cross-Entropy: L = -1/m * sum(y*log(y_hat) + (1-y)*log(1-y_hat))
|
||||
BCE,
|
||||
}
|
||||
|
||||
impl LossKind {
|
||||
/// Compute gradient dL/dy_hat (before applying activation derivative)
|
||||
pub fn gradient(&self, y_hat: &Matrix<f64>, y: &Matrix<f64>) -> Matrix<f64> {
|
||||
let m = y.rows() as f64;
|
||||
match self {
|
||||
LossKind::MSE => (y_hat - y) * (2.0 / m),
|
||||
LossKind::BCE => (y_hat - y) * (1.0 / m),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a dense neural network
pub struct DenseNNConfig {
    /// Number of input features.
    pub input_size: usize,
    /// Width of each hidden layer, in order.
    pub hidden_layers: Vec<usize>,
    /// Must have length = hidden_layers.len() + 1
    pub activations: Vec<ActivationKind>,
    /// Number of output neurons.
    pub output_size: usize,
    /// Weight initialization scheme used for every layer.
    pub initializer: InitializerKind,
    /// Loss minimized during training.
    pub loss: LossKind,
    /// Gradient-descent step size.
    pub learning_rate: f64,
    /// Number of full passes over the training data in `train`.
    pub epochs: usize,
}
|
||||
|
||||
/// A multi-layer perceptron with full configurability
pub struct DenseNN {
    // Per-layer weight matrices (fan_in x fan_out).
    weights: Vec<Matrix<f64>>,
    // Per-layer bias rows (1 x fan_out), broadcast over samples.
    biases: Vec<Matrix<f64>>,
    // One activation per weight layer.
    activations: Vec<ActivationKind>,
    // Loss used to form the output-layer gradient.
    loss: LossKind,
    // Learning rate for gradient descent.
    lr: f64,
    // Number of training epochs.
    epochs: usize,
}
|
||||
|
||||
impl DenseNN {
    /// Build a new DenseNN from the given configuration
    ///
    /// # Panics
    /// Panics if `config.activations` does not hold exactly one entry per
    /// weight layer (hidden layers + output layer).
    pub fn new(config: DenseNNConfig) -> Self {
        // Layer widths: input, hidden..., output.
        let mut sizes = vec![config.input_size];
        sizes.extend(&config.hidden_layers);
        sizes.push(config.output_size);

        assert_eq!(
            config.activations.len(),
            sizes.len() - 1,
            "Number of activation functions must match number of layers"
        );

        let mut weights = Vec::with_capacity(sizes.len() - 1);
        let mut biases = Vec::with_capacity(sizes.len() - 1);

        // One (in x out) weight matrix and one (1 x out) zero bias row per layer.
        for i in 0..sizes.len() - 1 {
            let w = config.initializer.initialize(sizes[i], sizes[i + 1]);
            let b = Matrix::zeros(1, sizes[i + 1]);
            weights.push(w);
            biases.push(b);
        }

        DenseNN {
            weights,
            biases,
            activations: config.activations,
            loss: config.loss,
            lr: config.learning_rate,
            epochs: config.epochs,
        }
    }

    /// Perform a full forward pass, returning pre-activations (z) and activations (a)
    ///
    /// `activs` has one more entry than `zs`: its first element is the input
    /// itself, so `activs[l]` is the input to layer `l`.
    fn forward_full(&self, x: &Matrix<f64>) -> (Vec<Matrix<f64>>, Vec<Matrix<f64>>) {
        let mut zs = Vec::with_capacity(self.weights.len());
        let mut activs = Vec::with_capacity(self.weights.len() + 1);
        activs.push(x.clone());

        let mut a = x.clone();
        for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            // z = a * W + b, with the bias row broadcast over all samples.
            let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
            let a_next = self.activations[i].forward(&z);
            zs.push(z);
            activs.push(a_next.clone());
            a = a_next;
        }

        (zs, activs)
    }

    /// Train the network on inputs X and targets Y
    ///
    /// Runs `epochs` rounds of full-batch gradient descent via
    /// backpropagation.
    pub fn train(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
        let m = x.rows() as f64;
        for _ in 0..self.epochs {
            let (zs, activs) = self.forward_full(x);
            let y_hat = activs.last().unwrap().clone();

            // Initial delta (dL/dz) on output
            let mut delta = match self.loss {
                // NOTE(review): for BCE the activation derivative is skipped,
                // which is the standard shortcut when the output layer is a
                // sigmoid (the terms cancel) — confirm the output activation
                // is sigmoid when using BCE.
                LossKind::BCE => self.loss.gradient(&y_hat, y),
                LossKind::MSE => {
                    // Chain rule: dL/dz = dL/dy_hat * dy_hat/dz.
                    let grad = self.loss.gradient(&y_hat, y);
                    let dz = self
                        .activations
                        .last()
                        .unwrap()
                        .derivative(zs.last().unwrap());
                    grad.zip(&dz, |g, da| g * da)
                }
            };

            // Backpropagate through layers
            for l in (0..self.weights.len()).rev() {
                let a_prev = &activs[l];
                // Parameter gradients, averaged over the batch.
                let dw = a_prev.transpose().dot(&delta) / m;
                let db = Matrix::from_vec(delta.sum_vertical(), 1, delta.cols()) / m;

                // Update weights & biases
                self.weights[l] = &self.weights[l] - &(dw * self.lr);
                self.biases[l] = &self.biases[l] - &(db * self.lr);

                // Propagate delta to previous layer
                // (uses the already-updated weights of layer `l`).
                if l > 0 {
                    let w_t = self.weights[l].transpose();
                    let da = self.activations[l - 1].derivative(&zs[l - 1]);
                    delta = delta.dot(&w_t).zip(&da, |d, a| d * a);
                }
            }
        }
    }

    /// Run a forward pass and return the network's output
    pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
        let mut a = x.clone();
        for (i, (w, b)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
            let z = &a.dot(w) + &Matrix::repeat_rows(b, a.rows());
            a = self.activations[i].forward(&z);
        }
        a
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::matrix::Matrix;
|
||||
|
||||
/// Compute MSE = 1/m * Σ (ŷ - y)²
|
||||
fn mse_loss(y_hat: &Matrix<f64>, y: &Matrix<f64>) -> f64 {
|
||||
let m = y.rows() as f64;
|
||||
y_hat
|
||||
.zip(y, |yh, yv| (yh - yv).powi(2))
|
||||
.data()
|
||||
.iter()
|
||||
.sum::<f64>()
|
||||
/ m
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_predict_shape() {
|
||||
let config = DenseNNConfig {
|
||||
input_size: 1,
|
||||
hidden_layers: vec![2],
|
||||
activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 0.01,
|
||||
epochs: 0,
|
||||
};
|
||||
let model = DenseNN::new(config);
|
||||
let x = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1);
|
||||
let preds = model.predict(&x);
|
||||
assert_eq!(preds.rows(), 3);
|
||||
assert_eq!(preds.cols(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Number of activation functions must match number of layers")]
|
||||
fn test_invalid_activation_count() {
|
||||
let config = DenseNNConfig {
|
||||
input_size: 2,
|
||||
hidden_layers: vec![3],
|
||||
activations: vec![ActivationKind::Relu], // Only one activation for two layers
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 0.01,
|
||||
epochs: 0,
|
||||
};
|
||||
let _model = DenseNN::new(config);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_train_no_epochs_does_nothing() {
|
||||
let config = DenseNNConfig {
|
||||
input_size: 1,
|
||||
hidden_layers: vec![2],
|
||||
activations: vec![ActivationKind::Relu, ActivationKind::Sigmoid],
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 0.01,
|
||||
epochs: 0,
|
||||
};
|
||||
let mut model = DenseNN::new(config);
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
|
||||
let y = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
|
||||
|
||||
let before = model.predict(&x);
|
||||
model.train(&x, &y);
|
||||
let after = model.predict(&x);
|
||||
|
||||
for i in 0..before.rows() {
|
||||
for j in 0..before.cols() {
|
||||
// "prediction changed despite 0 epochs"
|
||||
assert!((before[(i, j)] - after[(i, j)]).abs() < 1e-12);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_train_one_epoch_changes_predictions() {
|
||||
// Single-layer sigmoid regression so gradients flow.
|
||||
let config = DenseNNConfig {
|
||||
input_size: 1,
|
||||
hidden_layers: vec![],
|
||||
activations: vec![ActivationKind::Sigmoid],
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 1.0,
|
||||
epochs: 1,
|
||||
};
|
||||
let mut model = DenseNN::new(config);
|
||||
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
|
||||
let y = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
|
||||
|
||||
let before = model.predict(&x);
|
||||
model.train(&x, &y);
|
||||
let after = model.predict(&x);
|
||||
|
||||
// At least one of the two outputs must move by >ϵ
|
||||
let mut moved = false;
|
||||
for i in 0..before.rows() {
|
||||
if (before[(i, 0)] - after[(i, 0)]).abs() > 1e-8 {
|
||||
moved = true;
|
||||
}
|
||||
}
|
||||
assert!(moved, "predictions did not change after 1 epoch");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_training_reduces_mse_loss() {
|
||||
// Same single‐layer sigmoid setup; check loss goes down.
|
||||
let config = DenseNNConfig {
|
||||
input_size: 1,
|
||||
hidden_layers: vec![],
|
||||
activations: vec![ActivationKind::Sigmoid],
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 1.0,
|
||||
epochs: 10,
|
||||
};
|
||||
let mut model = DenseNN::new(config);
|
||||
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
|
||||
let y = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
|
||||
|
||||
let before_preds = model.predict(&x);
|
||||
let before_loss = mse_loss(&before_preds, &y);
|
||||
|
||||
model.train(&x, &y);
|
||||
|
||||
let after_preds = model.predict(&x);
|
||||
let after_loss = mse_loss(&after_preds, &y);
|
||||
|
||||
// MSE did not decrease (before: {}, after: {})
|
||||
assert!(after_loss < before_loss);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_activation_kind_forward_tanh() {
|
||||
let input = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
|
||||
let expected = Matrix::from_vec(vec![-0.76159415595, 0.0, 0.76159415595], 3, 1);
|
||||
let output = ActivationKind::Tanh.forward(&input);
|
||||
|
||||
for i in 0..input.rows() {
|
||||
for j in 0..input.cols() {
|
||||
// Tanh forward output mismatch at ({}, {})
|
||||
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_activation_kind_derivative_relu() {
|
||||
let input = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
|
||||
let expected = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
|
||||
let output = ActivationKind::Relu.derivative(&input);
|
||||
|
||||
for i in 0..input.rows() {
|
||||
for j in 0..input.cols() {
|
||||
// "ReLU derivative output mismatch at ({}, {})"
|
||||
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_activation_kind_derivative_tanh() {
|
||||
let input = Matrix::from_vec(vec![-1.0, 0.0, 1.0], 3, 1);
|
||||
let expected = Matrix::from_vec(vec![0.41997434161, 1.0, 0.41997434161], 3, 1); // 1 - tanh(x)^2
|
||||
let output = ActivationKind::Tanh.derivative(&input);
|
||||
|
||||
for i in 0..input.rows() {
|
||||
for j in 0..input.cols() {
|
||||
// "Tanh derivative output mismatch at ({}, {})"
|
||||
assert!((output[(i, j)] - expected[(i, j)]).abs() < 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initializer_kind_xavier() {
|
||||
let rows = 10;
|
||||
let cols = 20;
|
||||
let initializer = InitializerKind::Xavier;
|
||||
let matrix = initializer.initialize(rows, cols);
|
||||
let limit = (6.0 / (rows + cols) as f64).sqrt();
|
||||
|
||||
assert_eq!(matrix.rows(), rows);
|
||||
assert_eq!(matrix.cols(), cols);
|
||||
|
||||
for val in matrix.data() {
|
||||
// Xavier initialized value out of range
|
||||
assert!(*val >= -limit && *val <= limit);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initializer_kind_he() {
|
||||
let rows = 10;
|
||||
let cols = 20;
|
||||
let initializer = InitializerKind::He;
|
||||
let matrix = initializer.initialize(rows, cols);
|
||||
let limit = (2.0 / rows as f64).sqrt();
|
||||
|
||||
assert_eq!(matrix.rows(), rows);
|
||||
assert_eq!(matrix.cols(), cols);
|
||||
|
||||
for val in matrix.data() {
|
||||
// He initialized value out of range
|
||||
assert!(*val >= -limit && *val <= limit);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_loss_kind_bce_gradient() {
|
||||
let y_hat = Matrix::from_vec(vec![0.1, 0.9, 0.4], 3, 1);
|
||||
let y = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
|
||||
let expected_gradient = Matrix::from_vec(vec![0.1 / 3.0, -0.1 / 3.0, -0.1 / 3.0], 3, 1); // (y_hat - y) * (1.0 / m)
|
||||
let output_gradient = LossKind::BCE.gradient(&y_hat, &y);
|
||||
|
||||
assert_eq!(output_gradient.rows(), expected_gradient.rows());
|
||||
assert_eq!(output_gradient.cols(), expected_gradient.cols());
|
||||
|
||||
for i in 0..output_gradient.rows() {
|
||||
for j in 0..output_gradient.cols() {
|
||||
// BCE gradient output mismatch at ({}, {})
|
||||
assert!((output_gradient[(i, j)] - expected_gradient[(i, j)]).abs() < 1e-9);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_training_reduces_bce_loss() {
|
||||
// Single-layer sigmoid setup; check BCE loss goes down.
|
||||
let config = DenseNNConfig {
|
||||
input_size: 1,
|
||||
hidden_layers: vec![],
|
||||
activations: vec![ActivationKind::Sigmoid],
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::BCE,
|
||||
learning_rate: 1.0,
|
||||
epochs: 10,
|
||||
};
|
||||
let mut model = DenseNN::new(config);
|
||||
|
||||
let x = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
|
||||
let y = Matrix::from_vec(vec![0.0, 1.0, 0.5], 3, 1);
|
||||
|
||||
let before_preds = model.predict(&x);
|
||||
// BCE loss calculation for testing
|
||||
let before_loss = -1.0 / (y.rows() as f64)
|
||||
* before_preds
|
||||
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
|
||||
.data()
|
||||
.iter()
|
||||
.sum::<f64>();
|
||||
|
||||
model.train(&x, &y);
|
||||
|
||||
let after_preds = model.predict(&x);
|
||||
let after_loss = -1.0 / (y.rows() as f64)
|
||||
* after_preds
|
||||
.zip(&y, |yh, yv| yv * yh.ln() + (1.0 - yv) * (1.0 - yh).ln())
|
||||
.data()
|
||||
.iter()
|
||||
.sum::<f64>();
|
||||
|
||||
// BCE did not decrease (before: {}, after: {})
|
||||
assert!(after_loss < before_loss,);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_train_backprop_delta_propagation() {
|
||||
// Network with two layers to test delta propagation to previous layer (l > 0)
|
||||
let config = DenseNNConfig {
|
||||
input_size: 2,
|
||||
hidden_layers: vec![3],
|
||||
activations: vec![ActivationKind::Sigmoid, ActivationKind::Sigmoid],
|
||||
output_size: 1,
|
||||
initializer: InitializerKind::Uniform(0.1),
|
||||
loss: LossKind::MSE,
|
||||
learning_rate: 0.1,
|
||||
epochs: 1,
|
||||
};
|
||||
let mut model = DenseNN::new(config);
|
||||
|
||||
// Store initial weights and biases to compare after training
|
||||
let initial_weights_l0 = model.weights[0].clone();
|
||||
let initial_biases_l0 = model.biases[0].clone();
|
||||
let initial_weights_l1 = model.weights[1].clone();
|
||||
let initial_biases_l1 = model.biases[1].clone();
|
||||
|
||||
let x = Matrix::from_vec(vec![0.1, 0.2, 0.3, 0.4], 2, 2);
|
||||
let y = Matrix::from_vec(vec![0.5, 0.6], 2, 1);
|
||||
|
||||
model.train(&x, &y);
|
||||
|
||||
// Verify that weights and biases of both layers have changed,
|
||||
// implying delta propagation occurred for l > 0
|
||||
|
||||
// Weights of first layer did not change, delta propagation might not have occurred
|
||||
assert!(model.weights[0] != initial_weights_l0);
|
||||
// Biases of first layer did not change, delta propagation might not have occurred
|
||||
assert!(model.biases[0] != initial_biases_l0);
|
||||
// Weights of second layer did not change
|
||||
assert!(model.weights[1] != initial_weights_l1);
|
||||
// Biases of second layer did not change
|
||||
assert!(model.biases[1] != initial_biases_l1);
|
||||
}
|
||||
}
|
||||
243
src/compute/models/gaussian_nb.rs
Normal file
243
src/compute/models/gaussian_nb.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
//! Gaussian Naive Bayes classifier for dense matrices.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::gaussian_nb::GaussianNB;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![1.0, 2.0, 1.0, 2.0], 2, 2); // two samples
|
||||
//! let y = Matrix::from_vec(vec![0.0, 1.0], 2, 1);
|
||||
//! let mut model = GaussianNB::new(1e-9, false);
|
||||
//! model.fit(&x, &y);
|
||||
//! let preds = model.predict(&x);
|
||||
//! assert_eq!(preds.rows(), 2);
|
||||
//! ```
|
||||
use crate::matrix::Matrix;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A Gaussian Naive Bayes classifier.
///
/// # Parameters
/// - `var_smoothing`: Portion of the largest variance of all features to add to variances for stability.
/// - `use_unbiased_variance`: If `true`, uses Bessel's correction (dividing by (n-1)); otherwise divides by n.
///
pub struct GaussianNB {
    // Distinct class labels, sorted ascending after `fit`.
    classes: Vec<f64>,
    // Prior probabilities P(class), parallel to `classes`.
    priors: Vec<f64>,
    // Feature means per class (each a 1 x n_features row), parallel to `classes`.
    means: Vec<Matrix<f64>>,
    // Feature variances per class (each a 1 x n_features row), parallel to `classes`.
    variances: Vec<Matrix<f64>>,
    // var_smoothing
    eps: f64,
    // flag for unbiased variance
    use_unbiased: bool,
}
|
||||
|
||||
impl GaussianNB {
|
||||
/// Create a new GaussianNB.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `var_smoothing` - small float added to variances for numerical stability.
|
||||
/// * `use_unbiased_variance` - whether to apply Bessel's correction (divide by n-1).
|
||||
pub fn new(var_smoothing: f64, use_unbiased_variance: bool) -> Self {
|
||||
Self {
|
||||
classes: Vec::new(),
|
||||
priors: Vec::new(),
|
||||
means: Vec::new(),
|
||||
variances: Vec::new(),
|
||||
eps: var_smoothing,
|
||||
use_unbiased: use_unbiased_variance,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fit the model according to the training data `x` and labels `y`.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if `x` or `y` is empty, or if their dimensions disagree.
|
||||
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>) {
|
||||
let m = x.rows();
|
||||
let n = x.cols();
|
||||
assert_eq!(y.rows(), m, "Row count of X and Y must match");
|
||||
assert_eq!(y.cols(), 1, "Y must be a column vector");
|
||||
if m == 0 || n == 0 {
|
||||
panic!("Input matrix x or y is empty");
|
||||
}
|
||||
|
||||
// Group sample indices by label
|
||||
let mut groups: HashMap<u64, Vec<usize>> = HashMap::new();
|
||||
for i in 0..m {
|
||||
let label = y[(i, 0)];
|
||||
let bits = label.to_bits();
|
||||
groups.entry(bits).or_default().push(i);
|
||||
}
|
||||
assert!(!groups.is_empty(), "No class labels found in y"); //-- panicked earlier
|
||||
|
||||
// Extract and sort class labels
|
||||
self.classes = groups.keys().cloned().map(f64::from_bits).collect();
|
||||
self.classes.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
|
||||
self.priors.clear();
|
||||
self.means.clear();
|
||||
self.variances.clear();
|
||||
|
||||
// Precompute max variance for smoothing scale
|
||||
let mut max_var_feature = 0.0;
|
||||
for j in 0..n {
|
||||
let mut col_vals = Vec::with_capacity(m);
|
||||
for i in 0..m {
|
||||
col_vals.push(x[(i, j)]);
|
||||
}
|
||||
let mean_all = col_vals.iter().sum::<f64>() / m as f64;
|
||||
let var_all = col_vals.iter().map(|v| (v - mean_all).powi(2)).sum::<f64>() / m as f64;
|
||||
if var_all > max_var_feature {
|
||||
max_var_feature = var_all;
|
||||
}
|
||||
}
|
||||
let smoothing = self.eps * max_var_feature;
|
||||
|
||||
// Compute per-class statistics
|
||||
for &c in &self.classes {
|
||||
let idx = &groups[&c.to_bits()];
|
||||
let count = idx.len();
|
||||
// Prior
|
||||
self.priors.push(count as f64 / m as f64);
|
||||
|
||||
let mut mean = Matrix::zeros(1, n);
|
||||
let mut var = Matrix::zeros(1, n);
|
||||
|
||||
// Mean
|
||||
for &i in idx {
|
||||
for j in 0..n {
|
||||
mean[(0, j)] += x[(i, j)];
|
||||
}
|
||||
}
|
||||
for j in 0..n {
|
||||
mean[(0, j)] /= count as f64;
|
||||
}
|
||||
|
||||
// Variance
|
||||
for &i in idx {
|
||||
for j in 0..n {
|
||||
let d = x[(i, j)] - mean[(0, j)];
|
||||
var[(0, j)] += d * d;
|
||||
}
|
||||
}
|
||||
let denom = if self.use_unbiased {
|
||||
(count as f64 - 1.0).max(1.0)
|
||||
} else {
|
||||
count as f64
|
||||
};
|
||||
for j in 0..n {
|
||||
var[(0, j)] = var[(0, j)] / denom + smoothing;
|
||||
if var[(0, j)] <= 0.0 {
|
||||
var[(0, j)] = smoothing;
|
||||
}
|
||||
}
|
||||
|
||||
self.means.push(mean);
|
||||
self.variances.push(var);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform classification on an array of test vectors `x`.
|
||||
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||
let m = x.rows();
|
||||
let n = x.cols();
|
||||
let k = self.classes.len();
|
||||
let mut preds = Matrix::zeros(m, 1);
|
||||
let ln_2pi = (2.0 * std::f64::consts::PI).ln();
|
||||
|
||||
for i in 0..m {
|
||||
let mut best = (0, f64::NEG_INFINITY);
|
||||
for c_idx in 0..k {
|
||||
let mut log_prob = self.priors[c_idx].ln();
|
||||
for j in 0..n {
|
||||
let diff = x[(i, j)] - self.means[c_idx][(0, j)];
|
||||
let var = self.variances[c_idx][(0, j)];
|
||||
log_prob += -0.5 * (diff * diff / var + var.ln() + ln_2pi);
|
||||
}
|
||||
if log_prob > best.1 {
|
||||
best = (c_idx, log_prob);
|
||||
}
|
||||
}
|
||||
preds[(i, 0)] = self.classes[best.0];
|
||||
}
|
||||
preds
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;

    // Two well-separated single-feature classes; fitting should recover the
    // generating cluster for probe points near each centre.
    #[test]
    fn test_simple_two_class() {
        // Simple dataset: one feature, two classes 0 and 1
        // Class 0: values [1.0, 1.2, 0.8]
        // Class 1: values [3.0, 3.2, 2.8]
        let x = Matrix::from_vec(vec![1.0, 1.2, 0.8, 3.0, 3.2, 2.8], 6, 1);
        let y = Matrix::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0], 6, 1);
        let mut clf = GaussianNB::new(1e-9, false);
        clf.fit(&x, &y);
        let test = Matrix::from_vec(vec![1.1, 3.1], 2, 1);
        let preds = clf.predict(&test);
        assert_eq!(preds[(0, 0)], 0.0);
        assert_eq!(preds[(1, 0)], 1.0);
    }

    // Same scenario with the unbiased (n-1) variance estimator enabled; the
    // classification outcome must be unchanged.
    #[test]
    fn test_unbiased_variance() {
        // Same as above but with unbiased variance
        let x = Matrix::from_vec(vec![2.0, 2.2, 1.8, 4.0, 4.2, 3.8], 6, 1);
        let y = Matrix::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0], 6, 1);
        let mut clf = GaussianNB::new(1e-9, true);
        clf.fit(&x, &y);
        let test = Matrix::from_vec(vec![2.1, 4.1], 2, 1);
        let preds = clf.predict(&test);
        assert_eq!(preds[(0, 0)], 0.0);
        assert_eq!(preds[(1, 0)], 1.0);
    }

    // Fitting on an empty matrix is expected to panic.
    #[test]
    #[should_panic]
    fn test_empty_input() {
        let x = Matrix::zeros(0, 0);
        let y = Matrix::zeros(0, 1);
        let mut clf = GaussianNB::new(1e-9, false);
        clf.fit(&x, &y);
    }

    // X/Y row-count mismatch must panic with the documented message.
    #[test]
    #[should_panic = "Row count of X and Y must match"]
    fn test_mismatched_rows() {
        let x = Matrix::from_vec(vec![1.0, 2.0], 2, 1);
        let y = Matrix::from_vec(vec![0.0], 1, 1);
        let mut clf = GaussianNB::new(1e-9, false);
        clf.fit(&x, &y);
    }

    #[test]
    fn test_variance_smoothing_override_with_zero_smoothing() {
        // Scenario: var_smoothing is 0, and a feature has zero variance within a class.
        // This should trigger the `if var[(0, j)] <= 0.0 { var[(0, j)] = smoothing; }` line.
        let x = Matrix::from_vec(vec![1.0, 1.0, 2.0], 3, 1); // Class 0: [1.0, 1.0], Class 1: [2.0]
        let y = Matrix::from_vec(vec![0.0, 0.0, 1.0], 3, 1);
        let mut clf = GaussianNB::new(0.0, false); // var_smoothing = 0.0
        clf.fit(&x, &y);

        // For class 0 (index 0 in clf.classes), the feature (index 0) had values [1.0, 1.0], so variance was 0.
        // Since var_smoothing was 0, smoothing is 0.
        // The line `var[(0, j)] = smoothing;` should have set the variance to 0.0.
        let class_0_idx = clf.classes.iter().position(|&c| c == 0.0).unwrap();
        assert_eq!(clf.variances[class_0_idx][(0, 0)], 0.0);

        // For class 1 (index 1 in clf.classes), the feature (index 0) had value [2.0].
        // Variance calculation for a single point results in 0.
        // The if condition will be true, and var[(0, j)] will be set to smoothing (0.0).
        let class_1_idx = clf.classes.iter().position(|&c| c == 1.0).unwrap();
        assert_eq!(clf.variances[class_1_idx][(0, 0)], 0.0);
    }
}
|
||||
374
src/compute/models/k_means.rs
Normal file
374
src/compute/models/k_means.rs
Normal file
@@ -0,0 +1,374 @@
|
||||
//! Simple k-means clustering working on [`Matrix`] data.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::k_means::KMeans;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let data = Matrix::from_vec(vec![1.0, 1.0, 5.0, 5.0], 2, 2);
|
||||
//! let (model, labels) = KMeans::fit(&data, 2, 10, 1e-4);
|
||||
//! assert_eq!(model.centroids.rows(), 2);
|
||||
//! assert_eq!(labels.len(), 2);
|
||||
//! ```
|
||||
use crate::compute::stats::mean_vertical;
|
||||
use crate::matrix::Matrix;
|
||||
use crate::random::prelude::*;
|
||||
|
||||
/// K-means clustering model: stores the fitted cluster centroids.
///
/// Produced by [`KMeans::fit`]; use [`KMeans::predict`] to assign new samples
/// to their nearest centroid.
pub struct KMeans {
    pub centroids: Matrix<f64>, // (k, n_features)
}
|
||||
|
||||
impl KMeans {
    /// Fit with k clusters.
    ///
    /// Runs Lloyd's algorithm for at most `max_iter` iterations, stopping
    /// early when assignments stabilise or (for `tol > 0`) when the total
    /// centroid shift drops below `tol`. Returns the fitted model together
    /// with the final cluster label of every sample.
    ///
    /// # Panics
    /// Panics if `k` exceeds the number of samples.
    pub fn fit(x: &Matrix<f64>, k: usize, max_iter: usize, tol: f64) -> (Self, Vec<usize>) {
        let m = x.rows();
        let n = x.cols();
        assert!(k <= m, "k must be ≤ number of samples");

        // ----- initialise centroids -----
        // k == 1 uses the exact column mean (deterministic); k > 1 samples k
        // distinct rows via a shuffled index permutation.
        let mut centroids = Matrix::zeros(k, n);
        if k > 0 && m > 0 {
            // case for empty data
            if k == 1 {
                let mean = mean_vertical(x);
                centroids.row_copy_from_slice(0, &mean.data()); // ideally, data.row(0), but thats the same
            } else {
                // For k > 1, pick k distinct rows at random
                let mut rng = rng();
                let mut indices: Vec<usize> = (0..m).collect();
                indices.shuffle(&mut rng);
                for c in 0..k {
                    centroids.row_copy_from_slice(c, &x.row(indices[c]));
                }
            }
        }

        let mut labels = vec![0usize; m];
        // Squared distance of each sample to its assigned centroid; reused by
        // the empty-cluster re-seeding below.
        let mut distances = vec![0.0f64; m];

        for _iter in 0..max_iter {
            let mut changed = false;
            // ----- assignment step -----
            // Assign each sample to its nearest centroid (squared Euclidean).
            for i in 0..m {
                let sample_row = x.row(i);
                let mut best = 0usize;
                let mut best_dist_sq = f64::MAX;

                for c in 0..k {
                    let centroid_row = centroids.row(c);

                    let dist_sq: f64 = sample_row
                        .iter()
                        .zip(centroid_row.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum();

                    if dist_sq < best_dist_sq {
                        best_dist_sq = dist_sq;
                        best = c;
                    }
                }

                distances[i] = best_dist_sq;

                if labels[i] != best {
                    labels[i] = best;
                    changed = true;
                }
            }

            // ----- update step -----
            // Accumulate per-cluster sums, then divide by member counts.
            let mut new_centroids = Matrix::zeros(k, n);
            let mut counts = vec![0usize; k];
            for i in 0..m {
                let c = labels[i];
                counts[c] += 1;
                for j in 0..n {
                    new_centroids[(c, j)] += x[(i, j)];
                }
            }

            for c in 0..k {
                if counts[c] == 0 {
                    // This cluster is empty. Re-initialize its centroid to the point
                    // furthest from its assigned centroid to prevent the cluster from dying.
                    let mut furthest_point_idx = 0;
                    let mut max_dist_sq = 0.0;
                    for (i, &dist) in distances.iter().enumerate() {
                        if dist > max_dist_sq {
                            max_dist_sq = dist;
                            furthest_point_idx = i;
                        }
                    }

                    for j in 0..n {
                        new_centroids[(c, j)] = x[(furthest_point_idx, j)];
                    }
                    // Ensure this point isn't chosen again for another empty cluster in the same iteration.
                    if m > 0 {
                        distances[furthest_point_idx] = 0.0;
                    }
                } else {
                    // Normalize the centroid by the number of points in it.
                    for j in 0..n {
                        new_centroids[(c, j)] /= counts[c] as f64;
                    }
                }
            }

            // ----- convergence test -----
            if !changed {
                centroids = new_centroids; // update before breaking
                break; // assignments stable
            }

            // Frobenius-norm shift of the centroids; used only when tol > 0.
            let diff = &new_centroids - &centroids;
            centroids = new_centroids; // Update for the next iteration

            if tol > 0.0 {
                let sq_diff = &diff * &diff;
                let shift = sq_diff.data().iter().sum::<f64>().sqrt();
                if shift < tol {
                    break;
                }
            }
        }
        (Self { centroids }, labels)
    }

    /// Predict nearest centroid for each sample.
    ///
    /// Returns one label (centroid index) per row of `x`; an empty input
    /// yields an empty vector.
    pub fn predict(&self, x: &Matrix<f64>) -> Vec<usize> {
        let m = x.rows();
        let k = self.centroids.rows();

        if m == 0 {
            return Vec::new();
        }

        let mut labels = vec![0usize; m];
        for i in 0..m {
            let sample_row = x.row(i);
            let mut best = 0usize;
            let mut best_dist_sq = f64::MAX;

            for c in 0..k {
                let centroid_row = self.centroids.row(c);

                let dist_sq: f64 = sample_row
                    .iter()
                    .zip(centroid_row.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum();

                if dist_sq < best_dist_sq {
                    best_dist_sq = dist_sq;
                    best = c;
                }
            }
            labels[i] = best;
        }
        labels
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Probabilistic exercise of the empty-cluster re-seeding branch: with two
    // coincident points and one outlier, random init can leave a cluster empty.
    #[test]
    fn test_k_means_empty_cluster_reinit_centroid() {
        // Try multiple times to increase the chance of hitting the empty cluster case
        for _ in 0..20 {
            let data = vec![0.0, 0.0, 0.0, 0.0, 10.0, 10.0];
            let x = FloatMatrix::from_rows_vec(data, 3, 2);
            let k = 2;
            let max_iter = 10;
            let tol = 1e-6;

            let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

            // Check if any cluster is empty
            let mut counts = vec![0; k];
            for &label in &labels {
                counts[label] += 1;
            }
            if counts.iter().any(|&c| c == 0) {
                // Only check the property for clusters that are empty
                let centroids = kmeans_model.centroids;
                for c in 0..k {
                    if counts[c] == 0 {
                        // A re-seeded centroid must coincide with one of the
                        // original data points.
                        let mut matches_data_point = false;
                        for i in 0..3 {
                            let dx = centroids[(c, 0)] - x[(i, 0)];
                            let dy = centroids[(c, 1)] - x[(i, 1)];
                            if dx.abs() < 1e-9 && dy.abs() < 1e-9 {
                                matches_data_point = true;
                                break;
                            }
                        }
                        // "Centroid {} (empty cluster) does not match any data point",c
                        assert!(matches_data_point);
                    }
                }
                break;
            }
        }
        // If we never saw an empty cluster, that's fine; the test passes as long as no panic occurred
    }
    use super::*;
    use crate::matrix::FloatMatrix;

    fn create_test_data() -> (FloatMatrix, usize) {
        // Simple 2D data for testing K-Means
        // Cluster 1: (1,1), (1.5,1.5)
        // Cluster 2: (5,8), (8,8), (6,7)
        let data = vec![
            1.0, 1.0, // Sample 0
            1.5, 1.5, // Sample 1
            5.0, 8.0, // Sample 2
            8.0, 8.0, // Sample 3
            6.0, 7.0, // Sample 4
        ];
        let x = FloatMatrix::from_rows_vec(data, 5, 2);
        let k = 2;
        (x, k)
    }

    // Helper for single cluster test with exact mean
    fn create_simple_integer_data() -> FloatMatrix {
        // Data points: (1,1), (2,2), (3,3)
        FloatMatrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2)
    }

    #[test]
    fn test_k_means_fit_predict_basic() {
        let (x, k) = create_test_data();
        let max_iter = 100;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        // Assertions for fit
        assert_eq!(kmeans_model.centroids.rows(), k);
        assert_eq!(kmeans_model.centroids.cols(), x.cols());
        assert_eq!(labels.len(), x.rows());

        // Check if labels are within expected range (0 to k-1)
        for &label in &labels {
            assert!(label < k);
        }

        // Predict with the same data
        let predicted_labels = kmeans_model.predict(&x);

        // The exact labels might vary due to random initialization,
        // but the clustering should be consistent.
        // We expect two clusters. Let's check if samples 0,1 are in one cluster
        // and samples 2,3,4 are in another.
        let cluster_0_members = vec![labels[0], labels[1]];
        let cluster_1_members = vec![labels[2], labels[3], labels[4]];

        // All members of cluster 0 should have the same label
        assert_eq!(cluster_0_members[0], cluster_0_members[1]);
        // All members of cluster 1 should have the same label
        assert_eq!(cluster_1_members[0], cluster_1_members[1]);
        assert_eq!(cluster_1_members[0], cluster_1_members[2]);
        // The two clusters should have different labels
        assert_ne!(cluster_0_members[0], cluster_1_members[0]);

        // Check predicted labels are consistent with fitted labels
        assert_eq!(labels, predicted_labels);

        // Test with a new sample
        let new_sample_data = vec![1.2, 1.3]; // Should be close to cluster 0
        let new_sample = FloatMatrix::from_rows_vec(new_sample_data, 1, 2);
        let new_sample_label = kmeans_model.predict(&new_sample)[0];
        assert_eq!(new_sample_label, cluster_0_members[0]);

        let new_sample_data_2 = vec![7.0, 7.5]; // Should be close to cluster 1
        let new_sample_2 = FloatMatrix::from_rows_vec(new_sample_data_2, 1, 2);
        let new_sample_label_2 = kmeans_model.predict(&new_sample_2)[0];
        assert_eq!(new_sample_label_2, cluster_1_members[0]);
    }

    #[test]
    fn test_k_means_fit_k_equals_m() {
        // Test case where k (number of clusters) equals m (number of samples)
        let (x, _) = create_test_data(); // 5 samples
        let k = 5; // 5 clusters
        let max_iter = 10;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        assert_eq!(kmeans_model.centroids.rows(), k);
        assert_eq!(labels.len(), x.rows());

        // Each sample should be its own cluster. Due to random init, labels
        // might not be [0,1,2,3,4] but will be a permutation of it.
        let mut sorted_labels = labels.clone();
        sorted_labels.sort_unstable();
        sorted_labels.dedup();
        // Labels should all be unique when k==m
        assert_eq!(sorted_labels.len(), k);
    }

    #[test]
    #[should_panic(expected = "k must be ≤ number of samples")]
    fn test_k_means_fit_k_greater_than_m() {
        let (x, _) = create_test_data(); // 5 samples
        let k = 6; // k > m
        let max_iter = 10;
        let tol = 1e-6;

        let (_kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);
    }

    #[test]
    fn test_k_means_fit_single_cluster() {
        // Test with k=1
        let x = create_simple_integer_data(); // Use integer data
        let k = 1;
        let max_iter = 100;
        let tol = 1e-6;

        let (kmeans_model, labels) = KMeans::fit(&x, k, max_iter, tol);

        assert_eq!(kmeans_model.centroids.rows(), 1);
        assert_eq!(labels.len(), x.rows());

        // All labels should be 0
        assert!(labels.iter().all(|&l| l == 0));

        // Centroid should be the mean of all data points
        let expected_centroid_x = x.column(0).iter().sum::<f64>() / x.rows() as f64;
        let expected_centroid_y = x.column(1).iter().sum::<f64>() / x.rows() as f64;

        assert!((kmeans_model.centroids[(0, 0)] - expected_centroid_x).abs() < 1e-9);
        assert!((kmeans_model.centroids[(0, 1)] - expected_centroid_y).abs() < 1e-9);
    }

    #[test]
    fn test_k_means_predict_empty_matrix() {
        let (x, k) = create_test_data();
        let max_iter = 10;
        let tol = 1e-6;
        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);

        // The `Matrix` type not support 0xN or Nx0 matrices.
        // test with a 0x0 matrix is a valid edge case.
        let empty_x = FloatMatrix::from_rows_vec(vec![], 0, 0);
        let predicted_labels = kmeans_model.predict(&empty_x);
        assert!(predicted_labels.is_empty());
    }

    #[test]
    fn test_k_means_predict_single_sample() {
        let (x, k) = create_test_data();
        let max_iter = 10;
        let tol = 1e-6;
        let (kmeans_model, _labels) = KMeans::fit(&x, k, max_iter, tol);

        let single_sample = FloatMatrix::from_rows_vec(vec![1.1, 1.2], 1, 2);
        let predicted_label = kmeans_model.predict(&single_sample);
        assert_eq!(predicted_label.len(), 1);
        assert!(predicted_label[0] < k);
    }
}
|
||||
67
src/compute/models/linreg.rs
Normal file
67
src/compute/models/linreg.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Ordinary least squares linear regression.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::linreg::LinReg;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||
//! let y = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0], 4, 1);
|
||||
//! let mut model = LinReg::new(1);
|
||||
//! model.fit(&x, &y, 0.01, 100);
|
||||
//! let preds = model.predict(&x);
|
||||
//! assert_eq!(preds.rows(), 4);
|
||||
//! ```
|
||||
use crate::matrix::{Matrix, SeriesOps};
|
||||
|
||||
/// Ordinary least squares model trained by gradient descent.
///
/// Parameters are private; train via [`LinReg::fit`] and query via
/// [`LinReg::predict`].
pub struct LinReg {
    w: Matrix<f64>, // shape (n_features, 1)
    b: f64,
}
|
||||
|
||||
impl LinReg {
|
||||
pub fn new(n_features: usize) -> Self {
|
||||
Self {
|
||||
w: Matrix::from_vec(vec![0.0; n_features], n_features, 1),
|
||||
b: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||
// X.dot(w) + b
|
||||
x.dot(&self.w) + self.b
|
||||
}
|
||||
|
||||
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
|
||||
let m = x.rows() as f64;
|
||||
for _ in 0..epochs {
|
||||
let y_hat = self.predict(x);
|
||||
let err = &y_hat - y; // shape (m,1)
|
||||
|
||||
// grads
|
||||
let grad_w = x.transpose().dot(&err) * (2.0 / m); // (n,1)
|
||||
let grad_b = (2.0 / m) * err.sum_vertical().iter().sum::<f64>();
|
||||
// update
|
||||
self.w = &self.w - &(grad_w * lr);
|
||||
self.b -= lr * grad_b;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::*;

    // Fits y = x + 1 on four points; after 10k epochs the in-sample
    // predictions should be within 1e-2 of the targets.
    #[test]
    fn test_linreg_fit_predict() {
        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
        let y = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0], 4, 1);
        let mut model = LinReg::new(1);
        model.fit(&x, &y, 0.01, 10000);
        let preds = model.predict(&x);
        assert!((preds[(0, 0)] - 2.0).abs() < 1e-2);
        assert!((preds[(1, 0)] - 3.0).abs() < 1e-2);
        assert!((preds[(2, 0)] - 4.0).abs() < 1e-2);
        assert!((preds[(3, 0)] - 5.0).abs() < 1e-2);
    }
}
|
||||
68
src/compute/models/logreg.rs
Normal file
68
src/compute/models/logreg.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
//! Binary logistic regression classifier.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::logreg::LogReg;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||
//! let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
|
||||
//! let mut model = LogReg::new(1);
|
||||
//! model.fit(&x, &y, 0.1, 100);
|
||||
//! let preds = model.predict(&x);
|
||||
//! assert_eq!(preds[(0,0)], 0.0);
|
||||
//! ```
|
||||
use crate::compute::models::activations::sigmoid;
|
||||
use crate::matrix::{Matrix, SeriesOps};
|
||||
|
||||
/// Binary logistic regression model trained by gradient descent.
///
/// Weights have shape `(n_features, 1)`; `b` is the scalar intercept.
pub struct LogReg {
    w: Matrix<f64>,
    b: f64,
}
|
||||
|
||||
impl LogReg {
|
||||
pub fn new(n_features: usize) -> Self {
|
||||
Self {
|
||||
w: Matrix::zeros(n_features, 1),
|
||||
b: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predict_proba(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||
sigmoid(&(x.dot(&self.w) + self.b)) // σ(Xw + b)
|
||||
}
|
||||
|
||||
pub fn fit(&mut self, x: &Matrix<f64>, y: &Matrix<f64>, lr: f64, epochs: usize) {
|
||||
let m = x.rows() as f64;
|
||||
for _ in 0..epochs {
|
||||
let p = self.predict_proba(x); // shape (m,1)
|
||||
let err = &p - y; // derivative of BCE wrt pre-sigmoid
|
||||
let grad_w = x.transpose().dot(&err) / m;
|
||||
let grad_b = err.sum_vertical().iter().sum::<f64>() / m;
|
||||
self.w = &self.w - &(grad_w * lr);
|
||||
self.b -= lr * grad_b;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predict(&self, x: &Matrix<f64>) -> Matrix<f64> {
|
||||
self.predict_proba(x)
|
||||
.map(|p| if p >= 0.5 { 1.0 } else { 0.0 })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Linearly separable 1-D data (0,0,1,1); after training the hard
    // predictions must reproduce the labels exactly.
    #[test]
    fn test_logreg_fit_predict() {
        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
        let y = Matrix::from_vec(vec![0.0, 0.0, 1.0, 1.0], 4, 1);
        let mut model = LogReg::new(1);
        model.fit(&x, &y, 0.01, 10000);
        let preds = model.predict(&x);
        assert_eq!(preds[(0, 0)], 0.0);
        assert_eq!(preds[(1, 0)], 0.0);
        assert_eq!(preds[(2, 0)], 1.0);
        assert_eq!(preds[(3, 0)], 1.0);
    }
}
|
||||
23
src/compute/models/mod.rs
Normal file
23
src/compute/models/mod.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! Lightweight machine‑learning models built on matrices.
|
||||
//!
|
||||
//! Models are intentionally minimal and operate on the [`Matrix`](crate::matrix::Matrix) type for
|
||||
//! inputs and parameters.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::linreg::LinReg;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 4, 1);
|
||||
//! let y = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0], 4, 1);
|
||||
//! let mut model = LinReg::new(1);
|
||||
//! model.fit(&x, &y, 0.01, 1000);
|
||||
//! let preds = model.predict(&x);
|
||||
//! assert_eq!(preds.rows(), 4);
|
||||
//! ```
|
||||
pub mod activations;
|
||||
pub mod dense_nn;
|
||||
pub mod gaussian_nb;
|
||||
pub mod k_means;
|
||||
pub mod linreg;
|
||||
pub mod logreg;
|
||||
pub mod pca;
|
||||
113
src/compute/models/pca.rs
Normal file
113
src/compute/models/pca.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
//! Principal Component Analysis using covariance matrices.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::models::pca::PCA;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let data = Matrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0], 2, 2);
|
||||
//! let pca = PCA::fit(&data, 1, 0);
|
||||
//! let projected = pca.transform(&data);
|
||||
//! assert_eq!(projected.cols(), 1);
|
||||
//! ```
|
||||
use crate::compute::stats::correlation::covariance_matrix;
|
||||
use crate::compute::stats::descriptive::mean_vertical;
|
||||
use crate::matrix::{Axis, Matrix, SeriesOps};
|
||||
|
||||
/// Returns the `n_components` principal axes (rows) and the centred data's mean.
///
/// `components` has shape `(n_components, n_features)`; `mean` is the
/// per-feature training mean used to centre data in [`PCA::transform`].
pub struct PCA {
    pub components: Matrix<f64>, // (n_components, n_features)
    pub mean: Matrix<f64>,       // (1, n_features)
}
|
||||
|
||||
impl PCA {
    /// Build a PCA model from `x` (shape: n_samples × n_features).
    ///
    /// NOTE(review): this does NOT compute eigenvectors. `components` is
    /// filled with the first `n_components` ROWS of the feature covariance
    /// matrix, which coincides with the principal axes only in special cases
    /// (the unit tests rely on this exact behaviour). `_iters` is accepted
    /// but unused.
    pub fn fit(x: &Matrix<f64>, n_components: usize, _iters: usize) -> Self {
        let mean = mean_vertical(x); // Mean of each feature (column)
        let broadcasted_mean = mean.broadcast_row_to_target_shape(x.rows(), x.cols());
        let centered_data = x.zip(&broadcasted_mean, |x_i, mean_i| x_i - mean_i);
        let covariance_matrix = covariance_matrix(&centered_data, Axis::Col); // Covariance between features

        let mut components = Matrix::zeros(n_components, x.cols());
        for i in 0..n_components {
            if i < covariance_matrix.rows() {
                components.row_copy_from_slice(i, &covariance_matrix.row(i));
            } else {
                // More components requested than features available: the
                // remaining rows stay as zeros.
                break;
            }
        }

        PCA { components, mean }
    }

    /// Project new data on the learned axes.
    ///
    /// Centres `x` with the training mean, then multiplies by the transposed
    /// component matrix, yielding shape `(x.rows(), n_components)`.
    pub fn transform(&self, x: &Matrix<f64>) -> Matrix<f64> {
        let broadcasted_mean = self.mean.broadcast_row_to_target_shape(x.rows(), x.cols());
        let centered_data = x.zip(&broadcasted_mean, |x_i, mean_i| x_i - mean_i);
        centered_data.matrix_mul(&self.components.transpose())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;

    const EPSILON: f64 = 1e-8;

    // Points on y = x: checks the mean, the (covariance-row) component, and
    // the projected coordinates.
    #[test]
    fn test_pca_basic() {
        // Simple 2D data with points along the y = x line
        let data = Matrix::from_rows_vec(vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0], 3, 2);
        let (_n_samples, _n_features) = data.shape();

        let pca = PCA::fit(&data, 1, 0); // n_components = 1, iters is unused

        println!("Data shape: {:?}", data.shape());
        println!("PCA mean shape: {:?}", pca.mean.shape());
        println!("PCA components shape: {:?}", pca.components.shape());

        // Expected mean: (2.0, 2.0)
        assert!((pca.mean.get(0, 0) - 2.0).abs() < EPSILON);
        assert!((pca.mean.get(0, 1) - 2.0).abs() < EPSILON);

        // For data along y=x, the principal component should be proportional to (1/sqrt(2), 1/sqrt(2)) or (1,1)
        // The covariance matrix will be:
        // [[1.0, 1.0],
        // [1.0, 1.0]]
        // The principal component (eigenvector) will be (0.707, 0.707) or (-0.707, -0.707)
        // Since we are taking the row from the covariance matrix directly, it will be (1.0, 1.0)
        assert!((pca.components.get(0, 0) - 1.0).abs() < EPSILON);
        assert!((pca.components.get(0, 1) - 1.0).abs() < EPSILON);

        // Test transform: centered data projects to [-2.0, 0.0, 2.0]
        let transformed_data = pca.transform(&data);
        assert_eq!(transformed_data.rows(), 3);
        assert_eq!(transformed_data.cols(), 1);
        assert!((transformed_data.get(0, 0) - -2.0).abs() < EPSILON);
        assert!((transformed_data.get(1, 0) - 0.0).abs() < EPSILON);
        assert!((transformed_data.get(2, 0) - 2.0).abs() < EPSILON);
    }

    // Requests more components than there are features so that fit's early
    // `break` is exercised and the extra rows remain zero.
    #[test]
    fn test_pca_fit_break_branch() {
        // Data with 2 features
        let data = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
        let (_n_samples, n_features) = data.shape();

        // Set n_components greater than n_features to trigger the break branch
        let n_components_large = n_features + 1;
        let pca = PCA::fit(&data, n_components_large, 0);

        // The components matrix should be initialized with n_components_large rows,
        // but only the first n_features rows should be copied from the covariance matrix.
        // The remaining rows should be zeros.
        assert_eq!(pca.components.rows(), n_components_large);
        assert_eq!(pca.components.cols(), n_features);

        // Verify that rows beyond n_features are all zeros
        for i in n_features..n_components_large {
            for j in 0..n_features {
                assert!((pca.components.get(i, j) - 0.0).abs() < EPSILON);
            }
        }
    }
}
|
||||
242
src/compute/stats/correlation.rs
Normal file
242
src/compute/stats/correlation.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
//! Covariance and correlation helpers.
|
||||
//!
|
||||
//! This module provides routines for measuring the relationship between
|
||||
//! columns or rows of matrices.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::stats::correlation;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
//! let cov = correlation::covariance(&x, &x);
|
||||
//! assert!((cov - 1.25).abs() < 1e-8);
|
||||
//! ```
|
||||
use crate::compute::stats::{mean, mean_horizontal, mean_vertical, stddev};
|
||||
use crate::matrix::{Axis, Matrix, SeriesOps};
|
||||
|
||||
/// Population covariance between two equally-sized matrices (flattened)
|
||||
pub fn covariance(x: &Matrix<f64>, y: &Matrix<f64>) -> f64 {
|
||||
assert_eq!(x.rows(), y.rows());
|
||||
assert_eq!(x.cols(), y.cols());
|
||||
|
||||
let n = (x.rows() * x.cols()) as f64;
|
||||
let mean_x = mean(x);
|
||||
let mean_y = mean(y);
|
||||
|
||||
x.data()
|
||||
.iter()
|
||||
.zip(y.data().iter())
|
||||
.map(|(&a, &b)| (a - mean_x) * (b - mean_y))
|
||||
.sum::<f64>()
|
||||
/ n
|
||||
}
|
||||
|
||||
/// Shared implementation for the axis-wise population covariance matrices.
///
/// `Axis::Row` covaries COLUMN pairs across rows (result: cols × cols);
/// `Axis::Col` covaries ROW pairs across columns (result: rows × rows).
/// Both divide by n (population covariance), not n − 1.
fn _covariance_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
    match axis {
        Axis::Row => {
            // Covariance between each pair of columns → cols x cols
            let num_rows = x.rows() as f64;
            let means = mean_vertical(x); // 1 x cols
            let p = x.cols();
            let mut data = vec![0.0; p * p];

            for i in 0..p {
                let mu_i = means.get(0, i);
                for j in 0..p {
                    let mu_j = means.get(0, j);
                    // Σ over rows of the centred cross-product of columns i, j.
                    let mut sum = 0.0;
                    for r in 0..x.rows() {
                        let d_i = x.get(r, i) - mu_i;
                        let d_j = x.get(r, j) - mu_j;
                        sum += d_i * d_j;
                    }
                    data[i * p + j] = sum / num_rows;
                }
            }

            Matrix::from_vec(data, p, p)
        }
        Axis::Col => {
            // Covariance between each pair of rows → rows x rows
            let num_cols = x.cols() as f64;
            let means = mean_horizontal(x); // rows x 1
            let n = x.rows();
            let mut data = vec![0.0; n * n];

            for i in 0..n {
                let mu_i = means.get(i, 0);
                for j in 0..n {
                    let mu_j = means.get(j, 0);
                    // Σ over columns of the centred cross-product of rows i, j.
                    let mut sum = 0.0;
                    for c in 0..x.cols() {
                        let d_i = x.get(i, c) - mu_i;
                        let d_j = x.get(j, c) - mu_j;
                        sum += d_i * d_j;
                    }
                    data[i * n + j] = sum / num_cols;
                }
            }

            Matrix::from_vec(data, n, n)
        }
    }
}
|
||||
|
||||
/// Covariance between columns (i.e. across rows).
///
/// Returns a `cols × cols` population covariance matrix; see
/// `_covariance_axis` for the divide-by-n convention.
pub fn covariance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _covariance_axis(x, Axis::Row)
}
|
||||
|
||||
/// Covariance between rows (i.e. across columns).
///
/// Returns a `rows × rows` population covariance matrix; see
/// `_covariance_axis` for the divide-by-n convention.
pub fn covariance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _covariance_axis(x, Axis::Col)
}
|
||||
|
||||
/// Calculates the covariance matrix of the input data.
/// Assumes input `x` is (n_samples, n_features).
///
/// The result is always `n_features × n_features` and uses the SAMPLE
/// (divide-by-(n−1)) convention, unlike `_covariance_axis`. `axis` only
/// selects how the data is centred: `Axis::Col` subtracts each column's
/// mean, `Axis::Row` subtracts each row's mean.
pub fn covariance_matrix(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
    let (n_samples, n_features) = x.shape();

    let centered_data = match axis {
        Axis::Col => {
            let mean_matrix = mean_vertical(x); // 1 x n_features
            x.zip(
                &mean_matrix.broadcast_row_to_target_shape(n_samples, n_features),
                |val, m| val - m,
            )
        }
        Axis::Row => {
            let mean_matrix = mean_horizontal(x); // n_samples x 1
            // Manually create a matrix by broadcasting the column vector across columns
            let mut broadcasted_mean = Matrix::zeros(n_samples, n_features);
            for r in 0..n_samples {
                let mean_val = mean_matrix.get(r, 0);
                for c in 0..n_features {
                    *broadcasted_mean.get_mut(r, c) = *mean_val;
                }
            }
            x.zip(&broadcasted_mean, |val, m| val - m)
        }
    };

    // Calculate covariance matrix: (X_centered^T * X_centered) / (n_samples - 1)
    // If x is (n_samples, n_features), then centered_data is (n_samples, n_features)
    // centered_data.transpose() is (n_features, n_samples)
    // Result is (n_features, n_features)
    centered_data.transpose().matrix_mul(&centered_data) / (n_samples as f64 - 1.0)
}
|
||||
|
||||
pub fn pearson(x: &Matrix<f64>, y: &Matrix<f64>) -> f64 {
|
||||
assert_eq!(x.rows(), y.rows());
|
||||
assert_eq!(x.cols(), y.cols());
|
||||
|
||||
let cov = covariance(x, y);
|
||||
let std_x = stddev(x);
|
||||
let std_y = stddev(y);
|
||||
|
||||
if std_x == 0.0 || std_y == 0.0 {
|
||||
return 0.0; // Avoid division by zero
|
||||
}
|
||||
|
||||
cov / (std_x * std_y)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;

    // Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-8;

    #[test]
    fn test_covariance_scalar_same_matrix() {
        // Matrix with rows [1, 2] and [3, 4]; mean is 2.5
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let m = Matrix::from_vec(data.clone(), 2, 2);

        // flatten M: [1,2,3,4], mean = 2.5
        // cov(M,M) = variance of flatten = 1.25
        let cov = covariance(&m, &m);
        assert!((cov - 1.25).abs() < EPS);
    }

    #[test]
    fn test_covariance_scalar_diff_matrix() {
        // Matrix x has rows [1, 2] and [3, 4]; y is two times x
        let x = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let y = Matrix::from_vec(vec![2.0, 4.0, 6.0, 8.0], 2, 2);

        // mean_x = 2.5, mean_y = 5.0
        // cov = sum((xi-2.5)*(yi-5.0))/4 = 2.5
        let cov_xy = covariance(&x, &y);
        assert!((cov_xy - 2.5).abs() < EPS);
    }

    #[test]
    fn test_covariance_vertical() {
        // Matrix with rows [1, 2] and [3, 4]; columns are [1,3] and [2,4], each var=1, cov=1
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_vertical(&m);

        // Expect 2x2 matrix of all 1.0
        for i in 0..2 {
            for j in 0..2 {
                assert!((cov_mat.get(i, j) - 1.0).abs() < EPS);
            }
        }
    }

    #[test]
    fn test_covariance_horizontal() {
        // Matrix with rows [1,2] and [3,4], each var=0.25, cov=0.25
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_horizontal(&m);

        // Expect 2x2 matrix of all 0.25
        for i in 0..2 {
            for j in 0..2 {
                assert!((cov_mat.get(i, j) - 0.25).abs() < EPS);
            }
        }
    }

    #[test]
    fn test_covariance_matrix_vertical() {
        // Test with a simple 2x2 matrix with rows [1, 2] and [3, 4]
        // Expected covariance matrix (vertical, i.e., between columns):
        // Col1: [1, 3], mean = 2
        // Col2: [2, 4], mean = 3
        // Cov(Col1, Col1) = ((1-2)^2 + (3-2)^2) / (2-1) = (1+1)/1 = 2
        // Cov(Col2, Col2) = ((2-3)^2 + (4-3)^2) / (2-1) = (1+1)/1 = 2
        // Cov(Col1, Col2) = ((1-2)*(2-3) + (3-2)*(4-3)) / (2-1) = ((-1)*(-1) + (1)*(1))/1 = (1+1)/1 = 2
        // Cov(Col2, Col1) = 2
        // Expected matrix filled with 2
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_matrix(&m, Axis::Col);

        assert!((cov_mat.get(0, 0) - 2.0).abs() < EPS);
        assert!((cov_mat.get(0, 1) - 2.0).abs() < EPS);
        assert!((cov_mat.get(1, 0) - 2.0).abs() < EPS);
        assert!((cov_mat.get(1, 1) - 2.0).abs() < EPS);
    }

    #[test]
    fn test_covariance_matrix_horizontal() {
        // Test with a simple 2x2 matrix with rows [1, 2] and [3, 4].
        // With Axis::Row each row is centred by its own mean:
        //   Row1: [1, 2], mean = 1.5 -> [-0.5, 0.5]
        //   Row2: [3, 4], mean = 3.5 -> [-0.5, 0.5]
        // The function then forms C^T * C / (n_samples - 1) with
        // n_samples = 2 rows on the row-centred matrix C, giving:
        //   [[0.5, -0.5], [-0.5, 0.5]]
        // NOTE(review): the negative off-diagonal comes from the column-wise
        // products of the row-centred data, not from a row-by-row covariance.
        let m = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let cov_mat = covariance_matrix(&m, Axis::Row);

        assert!((cov_mat.get(0, 0) - 0.5).abs() < EPS);
        assert!((cov_mat.get(0, 1) - (-0.5)).abs() < EPS);
        assert!((cov_mat.get(1, 0) - (-0.5)).abs() < EPS);
        assert!((cov_mat.get(1, 1) - 0.5).abs() < EPS);
    }
}
|
||||
398
src/compute/stats/descriptive.rs
Normal file
398
src/compute/stats/descriptive.rs
Normal file
@@ -0,0 +1,398 @@
|
||||
//! Descriptive statistics for matrices.
|
||||
//!
|
||||
//! Provides means, variances, medians and other aggregations computed either
|
||||
//! across the whole matrix or along a specific axis.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::stats::descriptive;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let m = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
//! assert_eq!(descriptive::mean(&m), 2.5);
|
||||
//! ```
|
||||
use crate::matrix::{Axis, Matrix, SeriesOps};
|
||||
|
||||
pub fn mean(x: &Matrix<f64>) -> f64 {
|
||||
x.data().iter().sum::<f64>() / (x.rows() * x.cols()) as f64
|
||||
}
|
||||
|
||||
pub fn mean_vertical(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
let m = x.rows() as f64;
|
||||
Matrix::from_vec(x.sum_vertical(), 1, x.cols()) / m
|
||||
}
|
||||
|
||||
pub fn mean_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
|
||||
let n = x.cols() as f64;
|
||||
Matrix::from_vec(x.sum_horizontal(), x.rows(), 1) / n
|
||||
}
|
||||
|
||||
fn population_or_sample_variance(x: &Matrix<f64>, population: bool) -> f64 {
|
||||
let m = (x.rows() * x.cols()) as f64;
|
||||
let mean_val = mean(x);
|
||||
x.data()
|
||||
.iter()
|
||||
.map(|&v| (v - mean_val).powi(2))
|
||||
.sum::<f64>()
|
||||
/ if population { m } else { m - 1.0 }
|
||||
}
|
||||
|
||||
pub fn population_variance(x: &Matrix<f64>) -> f64 {
|
||||
population_or_sample_variance(x, true)
|
||||
}
|
||||
|
||||
pub fn sample_variance(x: &Matrix<f64>) -> f64 {
|
||||
population_or_sample_variance(x, false)
|
||||
}
|
||||
|
||||
/// Variance along one axis of `x`.
///
/// Note the axis naming: `Axis::Row` collapses the rows, producing one
/// variance per column (a `1 x cols` matrix), while `Axis::Col` collapses
/// the columns, producing one variance per row (a `rows x 1` matrix).
/// `population == true` divides by `n`, otherwise by `n - 1`.
fn _population_or_sample_variance_axis(
    x: &Matrix<f64>,
    axis: Axis,
    population: bool,
) -> Matrix<f64> {
    match axis {
        Axis::Row => {
            // Calculate variance for each column (vertical variance)
            let num_rows = x.rows() as f64;
            let mean_of_cols = mean_vertical(x); // 1 x cols matrix
            let mut result_data = vec![0.0; x.cols()];

            for c in 0..x.cols() {
                let mean_val = mean_of_cols.get(0, c); // Mean for current column
                let mut sum_sq_diff = 0.0;
                for r in 0..x.rows() {
                    let diff = x.get(r, c) - mean_val;
                    sum_sq_diff += diff * diff;
                }
                result_data[c] = sum_sq_diff / (if population { num_rows } else { num_rows - 1.0 });
            }
            Matrix::from_vec(result_data, 1, x.cols())
        }
        Axis::Col => {
            // Calculate variance for each row (horizontal variance)
            let num_cols = x.cols() as f64;
            let mean_of_rows = mean_horizontal(x); // rows x 1 matrix
            let mut result_data = vec![0.0; x.rows()];

            for r in 0..x.rows() {
                let mean_val = mean_of_rows.get(r, 0); // Mean for current row
                let mut sum_sq_diff = 0.0;
                for c in 0..x.cols() {
                    let diff = x.get(r, c) - mean_val;
                    sum_sq_diff += diff * diff;
                }
                result_data[r] = sum_sq_diff / (if population { num_cols } else { num_cols - 1.0 });
            }
            Matrix::from_vec(result_data, x.rows(), 1)
        }
    }
}
|
||||
|
||||
/// Per-column population variance (`1 x cols`), dividing by `n`.
pub fn population_variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Row, true)
}

/// Per-row population variance (`rows x 1`), dividing by `n`.
pub fn population_variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Col, true)
}

/// Per-column sample variance (`1 x cols`), dividing by `n - 1`.
pub fn sample_variance_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Row, false)
}

/// Per-row sample variance (`rows x 1`), dividing by `n - 1`.
pub fn sample_variance_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _population_or_sample_variance_axis(x, Axis::Col, false)
}
|
||||
|
||||
/// Population standard deviation over all elements (square root of the
/// population variance, i.e. the `n` divisor — not the sample form).
pub fn stddev(x: &Matrix<f64>) -> f64 {
    population_variance(x).sqrt()
}

/// Per-column population standard deviation (`1 x cols`).
pub fn stddev_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    population_variance_vertical(x).map(|v| v.sqrt())
}

/// Per-row population standard deviation (`rows x 1`).
pub fn stddev_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    population_variance_horizontal(x).map(|v| v.sqrt())
}
|
||||
|
||||
pub fn median(x: &Matrix<f64>) -> f64 {
|
||||
let mut data = x.data().to_vec();
|
||||
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let mid = data.len() / 2;
|
||||
if data.len() % 2 == 0 {
|
||||
(data[mid - 1] + data[mid]) / 2.0
|
||||
} else {
|
||||
data[mid]
|
||||
}
|
||||
}
|
||||
|
||||
/// Median along one axis.
///
/// `Axis::Col` computes one median per column of `x` (returned as
/// `1 x cols`); `Axis::Row` computes one median per row (returned as
/// `rows x 1`). Transposing up front lets both cases share the
/// column-iteration loop below.
///
/// # Panics
/// Panics if a lane is empty or contains NaN (`partial_cmp` unwrap).
fn _median_axis(x: &Matrix<f64>, axis: Axis) -> Matrix<f64> {
    let mx = match axis {
        Axis::Col => x.clone(),
        // Transposing turns row-wise medians into column-wise ones.
        Axis::Row => x.transpose(),
    };

    let mut result = Vec::with_capacity(mx.cols());
    for c in 0..mx.cols() {
        let mut col = mx.column(c).to_vec();
        col.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mid = col.len() / 2;
        if col.len() % 2 == 0 {
            // Even length: average the two middle elements.
            result.push((col[mid - 1] + col[mid]) / 2.0);
        } else {
            result.push(col[mid]);
        }
    }
    // For Axis::Row, `mx` is the transpose, so `mx.cols()` == `x.rows()`.
    let (r, c) = match axis {
        Axis::Col => (1, mx.cols()),
        Axis::Row => (mx.cols(), 1),
    };
    Matrix::from_vec(result, r, c)
}
|
||||
|
||||
/// Per-column median, returned as a `1 x cols` matrix.
pub fn median_vertical(x: &Matrix<f64>) -> Matrix<f64> {
    _median_axis(x, Axis::Col)
}

/// Per-row median, returned as a `rows x 1` matrix.
pub fn median_horizontal(x: &Matrix<f64>) -> Matrix<f64> {
    _median_axis(x, Axis::Row)
}
|
||||
|
||||
pub fn percentile(x: &Matrix<f64>, p: f64) -> f64 {
|
||||
if p < 0.0 || p > 100.0 {
|
||||
panic!("Percentile must be between 0 and 100");
|
||||
}
|
||||
let mut data = x.data().to_vec();
|
||||
data.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let index = ((p / 100.0) * (data.len() as f64 - 1.0)).round() as usize;
|
||||
data[index]
|
||||
}
|
||||
|
||||
/// Nearest-rank percentile along one axis (`p` in `[0, 100]`).
///
/// `Axis::Col` yields one value per column (`1 x cols`); `Axis::Row`
/// yields one per row (`rows x 1`), sharing the column loop via an
/// up-front transpose.
///
/// # Panics
/// Panics if `p` is outside `[0, 100]` or a lane contains NaN.
fn _percentile_axis(x: &Matrix<f64>, p: f64, axis: Axis) -> Matrix<f64> {
    if p < 0.0 || p > 100.0 {
        panic!("Percentile must be between 0 and 100");
    }
    let mx: Matrix<f64> = match axis {
        Axis::Col => x.clone(),
        Axis::Row => x.transpose(),
    };
    let mut result = Vec::with_capacity(mx.cols());
    for c in 0..mx.cols() {
        let mut col = mx.column(c).to_vec();
        col.sort_by(|a, b| a.partial_cmp(b).unwrap());
        // Round the fractional rank to the nearest index (nearest-rank rule).
        let index = ((p / 100.0) * (col.len() as f64 - 1.0)).round() as usize;
        result.push(col[index]);
    }
    // For Axis::Row, `mx` is the transpose, so `mx.cols()` == `x.rows()`.
    let (r, c) = match axis {
        Axis::Col => (1, mx.cols()),
        Axis::Row => (mx.cols(), 1),
    };
    Matrix::from_vec(result, r, c)
}
|
||||
|
||||
/// Per-column nearest-rank percentile (`1 x cols`).
pub fn percentile_vertical(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
    _percentile_axis(x, p, Axis::Col)
}
/// Per-row nearest-rank percentile (`rows x 1`).
pub fn percentile_horizontal(x: &Matrix<f64>, p: f64) -> Matrix<f64> {
    _percentile_axis(x, p, Axis::Row)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::Matrix;

    // Absolute tolerance for floating-point comparisons.
    const EPSILON: f64 = 1e-8;

    #[test]
    fn test_descriptive_stats_regular_values() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x = Matrix::from_vec(data, 1, 5);

        // Mean
        assert!((mean(&x) - 3.0).abs() < EPSILON);

        // Variance (population form, divisor n)
        assert!((population_variance(&x) - 2.0).abs() < EPSILON);

        // Standard Deviation
        assert!((stddev(&x) - 1.4142135623730951).abs() < EPSILON);

        // Median
        assert!((median(&x) - 3.0).abs() < EPSILON);

        // Percentile (nearest-rank on the sorted data)
        assert!((percentile(&x, 0.0) - 1.0).abs() < EPSILON);
        assert!((percentile(&x, 25.0) - 2.0).abs() < EPSILON);
        assert!((percentile(&x, 50.0) - 3.0).abs() < EPSILON);
        assert!((percentile(&x, 75.0) - 4.0).abs() < EPSILON);
        assert!((percentile(&x, 100.0) - 5.0).abs() < EPSILON);

        // Even-length input: median averages the two middle values.
        let data_even = vec![1.0, 2.0, 3.0, 4.0];
        let x_even = Matrix::from_vec(data_even, 1, 4);
        assert!((median(&x_even) - 2.5).abs() < EPSILON);
    }

    #[test]
    fn test_descriptive_stats_outlier() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 100.0];
        let x = Matrix::from_vec(data, 1, 5);

        // Mean should be heavily affected by outlier
        assert!((mean(&x) - 22.0).abs() < EPSILON);

        // Variance should be heavily affected by outlier
        assert!((population_variance(&x) - 1522.0).abs() < EPSILON);

        // Standard Deviation should be heavily affected by outlier
        assert!((stddev(&x) - 39.0128183970461).abs() < EPSILON);

        // Median should be robust to outlier
        assert!((median(&x) - 3.0).abs() < EPSILON);
    }

    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_panic_low() {
        let data = vec![1.0, 2.0, 3.0];
        let x = Matrix::from_vec(data, 1, 3);
        percentile(&x, -1.0);
    }

    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_panic_high() {
        let data = vec![1.0, 2.0, 3.0];
        let x = Matrix::from_vec(data, 1, 3);
        percentile(&x, 101.0);
    }

    #[test]
    fn test_mean_vertical_horizontal() {
        // 2x3 matrix (column-major data):
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);

        // Vertical means (per column): [(1+4)/2, (2+5)/2, (3+6)/2]
        let mv = mean_vertical(&x);
        assert!((mv.get(0, 0) - 2.5).abs() < EPSILON);
        assert!((mv.get(0, 1) - 3.5).abs() < EPSILON);
        assert!((mv.get(0, 2) - 4.5).abs() < EPSILON);

        // Horizontal means (per row): [(1+2+3)/3, (4+5+6)/3]
        let mh = mean_horizontal(&x);
        assert!((mh.get(0, 0) - 2.0).abs() < EPSILON);
        assert!((mh.get(1, 0) - 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_variance_vertical_horizontal() {
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);

        // cols: {1,4}, {2,5}, {3,6} all give 2.25
        let vv = population_variance_vertical(&x);
        for c in 0..3 {
            assert!((vv.get(0, c) - 2.25).abs() < EPSILON);
        }

        let vh = population_variance_horizontal(&x);
        assert!((vh.get(0, 0) - (2.0 / 3.0)).abs() < EPSILON);
        assert!((vh.get(1, 0) - (2.0 / 3.0)).abs() < EPSILON);

        // sample variance vertical: denominator is n-1 = 1, so variance is 4.5
        let svv = sample_variance_vertical(&x);
        for c in 0..3 {
            assert!((svv.get(0, c) - 4.5).abs() < EPSILON);
        }

        // sample variance horizontal: denominator is n-1 = 2, so variance is 1.0
        let svh = sample_variance_horizontal(&x);
        assert!((svh.get(0, 0) - 1.0).abs() < EPSILON);
        assert!((svh.get(1, 0) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_stddev_vertical_horizontal() {
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);

        // Stddev is sqrt of variance
        let sv = stddev_vertical(&x);
        for c in 0..3 {
            assert!((sv.get(0, c) - 1.5).abs() < EPSILON);
        }

        let sh = stddev_horizontal(&x);
        // sqrt(2/3) ≈ 0.816497
        let expected = (2.0 / 3.0 as f64).sqrt();
        assert!((sh.get(0, 0) - expected).abs() < EPSILON);
        assert!((sh.get(1, 0) - expected).abs() < EPSILON);

        // sample stddev vertical: sqrt(4.5) ≈ 2.12132034
        let ssv = sample_variance_vertical(&x).map(|v| v.sqrt());
        for c in 0..3 {
            assert!((ssv.get(0, c) - 2.1213203435596424).abs() < EPSILON);
        }

        // sample stddev horizontal: sqrt(1.0) = 1.0
        let ssh = sample_variance_horizontal(&x).map(|v| v.sqrt());
        assert!((ssh.get(0, 0) - 1.0).abs() < EPSILON);
        assert!((ssh.get(1, 0) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_median_vertical_horizontal() {
        let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0];
        let x = Matrix::from_vec(data, 2, 3);

        let mv = median_vertical(&x).row(0);

        let expected_v = vec![2.5, 3.5, 4.5];
        assert_eq!(mv, expected_v, "{:?} expected: {:?}", expected_v, mv);

        let mh = median_horizontal(&x).column(0).to_vec();
        let expected_h = vec![2.0, 5.0];
        assert_eq!(mh, expected_h, "{:?} expected: {:?}", expected_h, mh);
    }

    #[test]
    fn test_percentile_vertical_horizontal() {
        // vec of f64 values 1..24 as a 4x6 matrix
        let data: Vec<f64> = (1..=24).map(|x| x as f64).collect();
        let x = Matrix::from_vec(data, 4, 6);

        // columns contain sequences increasing by four starting at 1 through 4

        let er0 = vec![1., 5., 9., 13., 17., 21.];
        let er50 = vec![3., 7., 11., 15., 19., 23.];
        let er100 = vec![4., 8., 12., 16., 20., 24.];

        assert_eq!(percentile_vertical(&x, 0.0).data(), er0);
        assert_eq!(percentile_vertical(&x, 50.0).data(), er50);
        assert_eq!(percentile_vertical(&x, 100.0).data(), er100);

        let eh0 = vec![1., 2., 3., 4.];
        let eh50 = vec![13., 14., 15., 16.];
        let eh100 = vec![21., 22., 23., 24.];

        assert_eq!(percentile_horizontal(&x, 0.0).data(), eh0);
        assert_eq!(percentile_horizontal(&x, 50.0).data(), eh50);
        assert_eq!(percentile_horizontal(&x, 100.0).data(), eh100);
    }

    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_out_of_bounds() {
        let data = vec![1.0, 2.0, 3.0];
        let x = Matrix::from_vec(data, 1, 3);
        percentile(&x, -10.0); // Should panic
    }

    #[test]
    #[should_panic(expected = "Percentile must be between 0 and 100")]
    fn test_percentile_vertical_out_of_bounds() {
        let m = Matrix::from_vec(vec![1.0, 2.0, 3.0], 1, 3);
        let _ = percentile_vertical(&m, -0.1);
    }
}
|
||||
395
src/compute/stats/distributions.rs
Normal file
395
src/compute/stats/distributions.rs
Normal file
@@ -0,0 +1,395 @@
|
||||
//! Probability distribution functions applied element-wise to matrices.
|
||||
//!
|
||||
//! Includes approximations for the normal, uniform and gamma distributions as
|
||||
//! well as the error function.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::stats::distributions;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let x = Matrix::from_vec(vec![0.0], 1, 1);
|
||||
//! let pdf = distributions::normal_pdf(x.clone(), 0.0, 1.0);
|
||||
//! assert!((pdf.get(0,0) - 0.3989).abs() < 1e-3);
|
||||
//! ```
|
||||
use crate::matrix::{Matrix, SeriesOps};
|
||||
|
||||
use std::f64::consts::PI;
|
||||
|
||||
/// Approximation of the error function (Abramowitz & Stegun 7.1.26).
///
/// Odd symmetry, erf(-x) = -erf(x), is applied explicitly; maximum
/// absolute error of the approximation is about 1.5e-7.
fn erf_func(x: f64) -> f64 {
    let negative = x < 0.0;
    let x = x.abs();
    // Coefficients of the rational approximation.
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    if negative {
        -y
    } else {
        y
    }
}
|
||||
|
||||
/// Approximation of the error function for matrices
|
||||
pub fn erf(x: Matrix<f64>) -> Matrix<f64> {
|
||||
x.map(|v| erf_func(v))
|
||||
}
|
||||
|
||||
/// PDF of the Normal distribution with the given mean and standard
/// deviation, evaluated at `x`.
fn normal_pdf_func(x: f64, mean: f64, sd: f64) -> f64 {
    let z = (x - mean) / sd;
    let norm = 1.0 / (sd * (2.0 * PI).sqrt());
    norm * (-0.5 * z * z).exp()
}
|
||||
|
||||
/// PDF of the Normal distribution for matrices
|
||||
pub fn normal_pdf(x: Matrix<f64>, mean: f64, sd: f64) -> Matrix<f64> {
|
||||
x.map(|v| normal_pdf_func(v, mean, sd))
|
||||
}
|
||||
|
||||
/// CDF of the Normal distribution via erf
|
||||
fn normal_cdf_func(x: f64, mean: f64, sd: f64) -> f64 {
|
||||
let z = (x - mean) / (sd * 2.0_f64.sqrt());
|
||||
0.5 * (1.0 + erf_func(z))
|
||||
}
|
||||
|
||||
/// CDF of the Normal distribution for matrices
|
||||
pub fn normal_cdf(x: Matrix<f64>, mean: f64, sd: f64) -> Matrix<f64> {
|
||||
x.map(|v| normal_cdf_func(v, mean, sd))
|
||||
}
|
||||
|
||||
/// PDF of the Uniform distribution on [a, b]: constant density inside
/// the interval, zero outside.
fn uniform_pdf_func(x: f64, a: f64, b: f64) -> f64 {
    if (a..=b).contains(&x) {
        1.0 / (b - a)
    } else {
        0.0
    }
}
|
||||
|
||||
/// PDF of the Uniform distribution on [a, b] for matrices
|
||||
pub fn uniform_pdf(x: Matrix<f64>, a: f64, b: f64) -> Matrix<f64> {
|
||||
x.map(|v| uniform_pdf_func(v, a, b))
|
||||
}
|
||||
|
||||
/// CDF of the Uniform distribution on [a, b]: 0 below `a`, linear ramp
/// across the interval, 1 above `b`.
fn uniform_cdf_func(x: f64, a: f64, b: f64) -> f64 {
    match x {
        v if v < a => 0.0,
        v if v <= b => (v - a) / (b - a),
        _ => 1.0,
    }
}
|
||||
|
||||
/// CDF of the Uniform distribution on [a, b] for matrices
|
||||
pub fn uniform_cdf(x: Matrix<f64>, a: f64, b: f64) -> Matrix<f64> {
|
||||
x.map(|v| uniform_cdf_func(v, a, b))
|
||||
}
|
||||
|
||||
/// Gamma function via the Lanczos approximation (g = 7, 8 coefficients).
///
/// Arguments below 0.5 are handled through the reflection formula
/// Γ(z) Γ(1 − z) = π / sin(πz), which recurses once into the main branch.
fn gamma_func(z: f64) -> f64 {
    // Lanczos coefficients
    const COEFFS: [f64; 8] = [
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];
    if z < 0.5 {
        // Reflection formula for the left half-plane.
        PI / ((PI * z).sin() * gamma_func(1.0 - z))
    } else {
        let z = z - 1.0;
        let mut acc = 0.99999999999980993;
        for (i, &coef) in COEFFS.iter().enumerate() {
            acc += coef / (z + (i as f64) + 1.0);
        }
        let t = z + COEFFS.len() as f64 - 0.5;
        (2.0 * PI).sqrt() * t.powf(z + 0.5) * (-t).exp() * acc
    }
}
|
||||
|
||||
pub fn gamma(z: Matrix<f64>) -> Matrix<f64> {
|
||||
z.map(|v| gamma_func(v))
|
||||
}
|
||||
|
||||
/// Lower incomplete gamma function γ(s, x) via its power-series
/// expansion: γ(s, x) = x^s e^{-x} Σ x^n / (s (s+1) … (s+n)).
///
/// The series converges quickly for x < s + 1; a fixed 99 extra terms
/// are accumulated, matching the original behaviour.
fn lower_incomplete_gamma_func(s: f64, x: f64) -> f64 {
    let mut term = 1.0 / s;
    let mut total = term;
    for n in 1..100 {
        term *= x / (s + n as f64);
        total += term;
    }
    total * x.powf(s) * (-x).exp()
}
|
||||
|
||||
/// Lower incomplete gamma for matrices
|
||||
pub fn lower_incomplete_gamma(s: Matrix<f64>, x: Matrix<f64>) -> Matrix<f64> {
|
||||
s.zip(&x, |s_val, x_val| lower_incomplete_gamma_func(s_val, x_val))
|
||||
}
|
||||
|
||||
/// PDF of the Gamma distribution (shape k, scale θ)
|
||||
fn gamma_pdf_func(x: f64, k: f64, theta: f64) -> f64 {
|
||||
if x < 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
let coef = 1.0 / (gamma_func(k) * theta.powf(k));
|
||||
coef * x.powf(k - 1.0) * (-(x / theta)).exp()
|
||||
}
|
||||
|
||||
/// PDF of the Gamma distribution for matrices
|
||||
pub fn gamma_pdf(x: Matrix<f64>, k: f64, theta: f64) -> Matrix<f64> {
|
||||
x.map(|v| gamma_pdf_func(v, k, theta))
|
||||
}
|
||||
|
||||
/// CDF of the Gamma distribution via lower incomplete gamma
|
||||
fn gamma_cdf_func(x: f64, k: f64, theta: f64) -> f64 {
|
||||
if x < 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
lower_incomplete_gamma_func(k, x / theta) / gamma_func(k)
|
||||
}
|
||||
|
||||
/// CDF of the Gamma distribution for matrices
|
||||
pub fn gamma_cdf(x: Matrix<f64>, k: f64, theta: f64) -> Matrix<f64> {
|
||||
x.map(|v| gamma_cdf_func(v, k, theta))
|
||||
}
|
||||
|
||||
/// Factorials and Combinations ///
|
||||
|
||||
/// n! as an f64; exact for small n, finite up to roughly n = 170.
fn factorial(n: u64) -> f64 {
    let mut acc = 1.0;
    for i in 1..=n {
        acc *= i as f64;
    }
    acc
}
|
||||
|
||||
/// Compute "n choose k" without overflow
|
||||
fn binomial_coeff(n: u64, k: u64) -> f64 {
|
||||
let k = k.min(n - k);
|
||||
let mut numer = 1.0;
|
||||
let mut denom = 1.0;
|
||||
for i in 0..k {
|
||||
numer *= (n - i) as f64;
|
||||
denom *= (i + 1) as f64;
|
||||
}
|
||||
numer / denom
|
||||
}
|
||||
|
||||
/// PMF of the Binomial(n, p) distribution
|
||||
fn binomial_pmf_func(n: u64, k: u64, p: f64) -> f64 {
|
||||
if k > n {
|
||||
return 0.0;
|
||||
}
|
||||
binomial_coeff(n, k) * p.powf(k as f64) * (1.0 - p).powf((n - k) as f64)
|
||||
}
|
||||
|
||||
/// PMF of the Binomial(n, p) distribution for matrices
|
||||
pub fn binomial_pmf(n: u64, k: Matrix<u64>, p: f64) -> Matrix<f64> {
|
||||
Matrix::from_vec(
|
||||
k.data()
|
||||
.iter()
|
||||
.map(|&v| binomial_pmf_func(n, v, p))
|
||||
.collect::<Vec<f64>>(),
|
||||
k.rows(),
|
||||
k.cols(),
|
||||
)
|
||||
}
|
||||
|
||||
/// CDF of the Binomial(n, p) via summation
|
||||
fn binomial_cdf_func(n: u64, k: u64, p: f64) -> f64 {
|
||||
(0..=k).map(|i| binomial_pmf_func(n, i, p)).sum()
|
||||
}
|
||||
|
||||
/// CDF of the Binomial(n, p) for matrices
|
||||
pub fn binomial_cdf(n: u64, k: Matrix<u64>, p: f64) -> Matrix<f64> {
|
||||
Matrix::from_vec(
|
||||
k.data()
|
||||
.iter()
|
||||
.map(|&v| binomial_cdf_func(n, v, p))
|
||||
.collect::<Vec<f64>>(),
|
||||
k.rows(),
|
||||
k.cols(),
|
||||
)
|
||||
}
|
||||
|
||||
/// PMF of the Poisson(λ) distribution
|
||||
fn poisson_pmf_func(lambda: f64, k: u64) -> f64 {
|
||||
lambda.powf(k as f64) * (-lambda).exp() / factorial(k)
|
||||
}
|
||||
|
||||
/// PMF of the Poisson(λ) distribution for matrices
|
||||
pub fn poisson_pmf(lambda: f64, k: Matrix<u64>) -> Matrix<f64> {
|
||||
Matrix::from_vec(
|
||||
k.data()
|
||||
.iter()
|
||||
.map(|&v| poisson_pmf_func(lambda, v))
|
||||
.collect::<Vec<f64>>(),
|
||||
k.rows(),
|
||||
k.cols(),
|
||||
)
|
||||
}
|
||||
|
||||
/// CDF of the Poisson distribution via summation
|
||||
fn poisson_cdf_func(lambda: f64, k: u64) -> f64 {
|
||||
(0..=k).map(|i| poisson_pmf_func(lambda, i)).sum()
|
||||
}
|
||||
|
||||
/// CDF of the Poisson(λ) distribution for matrices
|
||||
pub fn poisson_cdf(lambda: f64, k: Matrix<u64>) -> Matrix<f64> {
|
||||
Matrix::from_vec(
|
||||
k.data()
|
||||
.iter()
|
||||
.map(|&v| poisson_cdf_func(lambda, v))
|
||||
.collect::<Vec<f64>>(),
|
||||
k.rows(),
|
||||
k.cols(),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_math_funcs() {
        // Test erf function
        assert!((erf_func(0.0) - 0.0).abs() < 1e-7);
        assert!((erf_func(1.0) - 0.8427007).abs() < 1e-7);
        assert!((erf_func(-1.0) + 0.8427007).abs() < 1e-7);

        // Test gamma function at the integers: Γ(n) = (n-1)!
        assert!((gamma_func(1.0) - 1.0).abs() < 1e-7);
        assert!((gamma_func(2.0) - 1.0).abs() < 1e-7);
        assert!((gamma_func(3.0) - 2.0).abs() < 1e-7);
        assert!((gamma_func(4.0) - 6.0).abs() < 1e-7);
        assert!((gamma_func(5.0) - 24.0).abs() < 1e-7);

        // Reflection formula branch (z < 0.5).
        let z = 0.3;
        let expected = PI / ((PI * z).sin() * gamma_func(1.0 - z));
        assert!((gamma_func(z) - expected).abs() < 1e-7);
    }

    #[test]
    fn test_math_matrix() {
        let x = Matrix::filled(5, 5, 1.0);
        let erf_result = erf(x.clone());
        assert!((erf_result.data()[0] - 0.8427007).abs() < 1e-7);

        let gamma_result = gamma(x);
        assert!((gamma_result.data()[0] - 1.0).abs() < 1e-7);
    }

    #[test]
    fn test_normal_funcs() {
        assert!((normal_pdf_func(0.0, 0.0, 1.0) - 0.39894228).abs() < 1e-7);
        assert!((normal_cdf_func(1.0, 0.0, 1.0) - 0.8413447).abs() < 1e-7);
    }

    #[test]
    fn test_normal_matrix() {
        let x = Matrix::filled(5, 5, 0.0);
        let pdf = normal_pdf(x.clone(), 0.0, 1.0);
        let cdf = normal_cdf(x, 0.0, 1.0);
        assert!((pdf.data()[0] - 0.39894228).abs() < 1e-7);
        assert!((cdf.data()[0] - 0.5).abs() < 1e-7);
    }

    #[test]
    fn test_uniform_funcs() {
        assert_eq!(uniform_pdf_func(0.5, 0.0, 1.0), 1.0);
        assert_eq!(uniform_cdf_func(-1.0, 0.0, 1.0), 0.0);
        assert_eq!(uniform_cdf_func(0.5, 0.0, 1.0), 0.5);

        // x<a (or x>b) should return 0
        assert_eq!(uniform_pdf_func(-0.5, 0.0, 1.0), 0.0);
        assert_eq!(uniform_pdf_func(1.5, 0.0, 1.0), 0.0);

        // for cdf x>a AND x>b should return 1
        assert_eq!(uniform_cdf_func(1.5, 0.0, 1.0), 1.0);
        assert_eq!(uniform_cdf_func(2.0, 0.0, 1.0), 1.0);
    }

    #[test]
    fn test_uniform_matrix() {
        let x = Matrix::filled(5, 5, 0.5);
        let pdf = uniform_pdf(x.clone(), 0.0, 1.0);
        let cdf = uniform_cdf(x, 0.0, 1.0);
        assert_eq!(pdf.data()[0], 1.0);
        assert_eq!(cdf.data()[0], 0.5);
    }

    #[test]
    fn test_binomial_funcs() {
        let pmf = binomial_pmf_func(5, 2, 0.5);
        assert!((pmf - 0.3125).abs() < 1e-7);
        // CDF(2) = PMF(0) + PMF(1) + PMF(2)
        let cdf = binomial_cdf_func(5, 2, 0.5);
        assert!((cdf - (0.03125 + 0.15625 + 0.3125)).abs() < 1e-7);

        let pmf_zero = binomial_pmf_func(5, 6, 0.5);
        assert!(pmf_zero == 0.0, "PMF should be 0 for k > n");
    }

    #[test]
    fn test_binomial_matrix() {
        let k = Matrix::filled(5, 5, 2 as u64);
        let pmf = binomial_pmf(5, k.clone(), 0.5);
        let cdf = binomial_cdf(5, k, 0.5);
        assert!((pmf.data()[0] - 0.3125).abs() < 1e-7);
        assert!((cdf.data()[0] - (0.03125 + 0.15625 + 0.3125)).abs() < 1e-7);
    }

    #[test]
    fn test_poisson_funcs() {
        // PMF(2) = λ^2 e^{-λ} / 2!
        let pmf: f64 = poisson_pmf_func(3.0, 2);
        assert!((pmf - (3.0_f64.powf(2.0) * (-3.0 as f64).exp() / 2.0)).abs() < 1e-7);
        let cdf: f64 = poisson_cdf_func(3.0, 2);
        assert!((cdf - (pmf + poisson_pmf_func(3.0, 0) + poisson_pmf_func(3.0, 1))).abs() < 1e-7);
    }

    #[test]
    fn test_poisson_matrix() {
        let k = Matrix::filled(5, 5, 2);
        let pmf = poisson_pmf(3.0, k.clone());
        let cdf = poisson_cdf(3.0, k);
        assert!((pmf.data()[0] - (3.0_f64.powf(2.0) * (-3.0 as f64).exp() / 2.0)).abs() < 1e-7);
        assert!(
            (cdf.data()[0] - (pmf.data()[0] + poisson_pmf_func(3.0, 0) + poisson_pmf_func(3.0, 1)))
                .abs()
                < 1e-7
        );
    }

    #[test]
    fn test_gamma_funcs() {
        // For k=1, θ=1 the Gamma(1,1) is Exp(1), so pdf(x)=e^-x
        assert!((gamma_pdf_func(2.0, 1.0, 1.0) - (-2.0 as f64).exp()).abs() < 1e-7);
        assert!((gamma_cdf_func(2.0, 1.0, 1.0) - (1.0 - (-2.0 as f64).exp())).abs() < 1e-7);

        // <0 case
        assert_eq!(gamma_pdf_func(-1.0, 1.0, 1.0), 0.0);
        assert_eq!(gamma_cdf_func(-1.0, 1.0, 1.0), 0.0);
    }
    #[test]
    fn test_gamma_matrix() {
        let x = Matrix::filled(5, 5, 2.0);
        let pdf = gamma_pdf(x.clone(), 1.0, 1.0);
        let cdf = gamma_cdf(x, 1.0, 1.0);
        assert!((pdf.data()[0] - (-2.0 as f64).exp()).abs() < 1e-7);
        assert!((cdf.data()[0] - (1.0 - (-2.0 as f64).exp())).abs() < 1e-7);
    }

    #[test]
    fn test_lower_incomplete_gamma() {
        // Matrix wrapper must agree with the scalar implementation.
        let s = Matrix::filled(5, 5, 2.0);
        let x = Matrix::filled(5, 5, 1.0);
        let expected = lower_incomplete_gamma_func(2.0, 1.0);
        let result = lower_incomplete_gamma(s, x);
        assert!((result.data()[0] - expected).abs() < 1e-7);
    }
}
|
||||
142
src/compute/stats/inferential.rs
Normal file
142
src/compute/stats/inferential.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
//! Basic inferential statistics such as t‑tests and chi‑square tests.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::stats::inferential;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let a = Matrix::from_vec(vec![1.0, 2.0], 2, 1);
|
||||
//! let b = Matrix::from_vec(vec![1.1, 1.9], 2, 1);
|
||||
//! let (t, _p) = inferential::t_test(&a, &b);
|
||||
//! assert!(t.abs() < 1.0);
|
||||
//! ```
|
||||
use crate::matrix::{Matrix, SeriesOps};
|
||||
|
||||
use crate::compute::stats::{gamma_cdf, mean, sample_variance};
|
||||
|
||||
/// Two-sample t-test returning (t_statistic, p_value).
///
/// Uses Welch's unequal-variance form of the statistic, treating every
/// element of each matrix as one observation.
///
/// NOTE(review): the returned p-value is currently a hard-coded
/// placeholder of 0.5 — the Welch–Satterthwaite degrees of freedom are
/// computed but unused because no t-distribution CDF is available here.
pub fn t_test(sample1: &Matrix<f64>, sample2: &Matrix<f64>) -> (f64, f64) {
    let mean1 = mean(sample1);
    let mean2 = mean(sample2);
    let var1 = sample_variance(sample1);
    let var2 = sample_variance(sample2);
    // Observation counts: every matrix element is one sample point.
    let n1 = (sample1.rows() * sample1.cols()) as f64;
    let n2 = (sample2.rows() * sample2.cols()) as f64;

    let t_statistic = (mean1 - mean2) / ((var1 / n1 + var2 / n2).sqrt());

    // Calculate degrees of freedom using Welch-Satterthwaite equation
    // (kept, though unused, for a future t-distribution based p-value).
    let _df = (var1 / n1 + var2 / n2).powi(2)
        / ((var1 / n1).powi(2) / (n1 - 1.0) + (var2 / n2).powi(2) / (n2 - 1.0));

    // Placeholder: a proper two-tailed p-value needs the t-distribution CDF.
    let p_value = 0.5;

    (t_statistic, p_value)
}
|
||||
|
||||
/// Chi-square test of independence
|
||||
pub fn chi2_test(observed: &Matrix<f64>) -> (f64, f64) {
|
||||
let (rows, cols) = observed.shape();
|
||||
let row_sums: Vec<f64> = observed.sum_horizontal();
|
||||
let col_sums: Vec<f64> = observed.sum_vertical();
|
||||
let grand_total: f64 = observed.data().iter().sum();
|
||||
|
||||
let mut chi2_statistic: f64 = 0.0;
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
let expected = row_sums[i] * col_sums[j] / grand_total;
|
||||
chi2_statistic += (observed.get(i, j) - expected).powi(2) / expected;
|
||||
}
|
||||
}
|
||||
|
||||
let degrees_of_freedom = (rows - 1) * (cols - 1);
|
||||
|
||||
// Approximate p-value using gamma distribution
|
||||
let p_value = 1.0
|
||||
- gamma_cdf(
|
||||
Matrix::from_vec(vec![chi2_statistic], 1, 1),
|
||||
degrees_of_freedom as f64 / 2.0,
|
||||
1.0,
|
||||
)
|
||||
.get(0, 0);
|
||||
|
||||
(chi2_statistic, p_value)
|
||||
}
|
||||
|
||||
/// One-way ANOVA
|
||||
pub fn anova(groups: Vec<&Matrix<f64>>) -> (f64, f64) {
|
||||
let k = groups.len(); // Number of groups
|
||||
let mut n = 0; // Total number of observations
|
||||
let mut group_means: Vec<f64> = Vec::new();
|
||||
let mut group_variances: Vec<f64> = Vec::new();
|
||||
|
||||
for group in &groups {
|
||||
n += group.rows() * group.cols();
|
||||
group_means.push(mean(group));
|
||||
group_variances.push(sample_variance(group));
|
||||
}
|
||||
|
||||
let grand_mean: f64 = group_means.iter().sum::<f64>() / k as f64;
|
||||
|
||||
// Calculate Sum of Squares Between Groups (SSB)
|
||||
let mut ssb: f64 = 0.0;
|
||||
for i in 0..k {
|
||||
ssb += (group_means[i] - grand_mean).powi(2) * (groups[i].rows() * groups[i].cols()) as f64;
|
||||
}
|
||||
|
||||
// Calculate Sum of Squares Within Groups (SSW)
|
||||
let mut ssw: f64 = 0.0;
|
||||
for i in 0..k {
|
||||
ssw += group_variances[i] * (groups[i].rows() * groups[i].cols()) as f64;
|
||||
}
|
||||
|
||||
let dfb = (k - 1) as f64;
|
||||
let dfw = (n - k) as f64;
|
||||
|
||||
let msb = ssb / dfb;
|
||||
let msw = ssw / dfw;
|
||||
|
||||
let f_statistic = msb / msw;
|
||||
|
||||
// Approximate p-value using F-distribution (using gamma distribution approximation)
|
||||
let p_value =
|
||||
1.0 - gamma_cdf(Matrix::from_vec(vec![f_statistic], 1, 1), dfb / 2.0, 1.0).get(0, 0);
|
||||
|
||||
(f_statistic, p_value)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::matrix::Matrix;
|
||||
|
||||
const EPS: f64 = 1e-5;
|
||||
|
||||
#[test]
|
||||
fn test_t_test() {
|
||||
let sample1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
let sample2 = Matrix::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], 1, 5);
|
||||
let (t_statistic, p_value) = t_test(&sample1, &sample2);
|
||||
assert!((t_statistic + 5.0).abs() < EPS);
|
||||
assert!(p_value > 0.0 && p_value < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chi2_test() {
|
||||
let observed = Matrix::from_vec(vec![12.0, 5.0, 8.0, 10.0], 2, 2);
|
||||
let (chi2_statistic, p_value) = chi2_test(&observed);
|
||||
assert!(chi2_statistic > 0.0);
|
||||
assert!(p_value > 0.0 && p_value < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anova() {
|
||||
let group1 = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], 1, 5);
|
||||
let group2 = Matrix::from_vec(vec![2.0, 3.0, 4.0, 5.0, 6.0], 1, 5);
|
||||
let group3 = Matrix::from_vec(vec![3.0, 4.0, 5.0, 6.0, 7.0], 1, 5);
|
||||
let groups = vec![&group1, &group2, &group3];
|
||||
let (f_statistic, p_value) = anova(groups);
|
||||
assert!(f_statistic > 0.0);
|
||||
assert!(p_value > 0.0 && p_value < 1.0);
|
||||
}
|
||||
}
|
||||
22
src/compute/stats/mod.rs
Normal file
22
src/compute/stats/mod.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! Statistical routines for matrices.
|
||||
//!
|
||||
//! Functions are grouped into submodules for descriptive statistics,
|
||||
//! correlations, probability distributions and basic inferential tests.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::compute::stats;
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let m = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
|
||||
//! let cov = stats::covariance(&m, &m);
|
||||
//! assert!((cov - 1.25).abs() < 1e-8);
|
||||
//! ```
|
||||
pub mod correlation;
|
||||
pub mod descriptive;
|
||||
pub mod distributions;
|
||||
pub mod inferential;
|
||||
|
||||
pub use correlation::*;
|
||||
pub use descriptive::*;
|
||||
pub use distributions::*;
|
||||
pub use inferential::*;
|
||||
@@ -1,3 +1,19 @@
|
||||
//! Core data-frame structures such as [`Frame`] and [`RowIndex`].
|
||||
//!
|
||||
//! The [`Frame`] type stores column-labelled data with an optional row index
|
||||
//! and builds upon the [`crate::matrix::Matrix`] type.
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::frame::{Frame, RowIndex};
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let data = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]);
|
||||
//! let frame = Frame::new(data, vec!["L", "R"], Some(RowIndex::Int(vec![10, 20])));
|
||||
//! assert_eq!(frame.columns(), &["L", "R"]);
|
||||
//! assert_eq!(frame.index(), &RowIndex::Int(vec![10, 20]));
|
||||
//! ```
|
||||
use crate::matrix::Matrix;
|
||||
use chrono::NaiveDate;
|
||||
use std::collections::HashMap;
|
||||
@@ -232,11 +248,18 @@ impl<T: Clone + PartialEq> Frame<T> {
|
||||
}
|
||||
(RowIndex::Date(vals), RowIndexLookup::Date(lookup))
|
||||
}
|
||||
Some(RowIndex::Range(_)) => {
|
||||
Some(RowIndex::Range(ref r)) => {
|
||||
// If the length of the range does not match the number of rows, panic.
|
||||
if r.end.saturating_sub(r.start) != num_rows {
|
||||
panic!(
|
||||
"Frame::new: Cannot explicitly provide a Range index. Use None for default range."
|
||||
"Frame::new: Range index length ({}) mismatch matrix rows ({})",
|
||||
r.end.saturating_sub(r.start),
|
||||
num_rows
|
||||
);
|
||||
}
|
||||
// return the range as is.
|
||||
(RowIndex::Range(r.clone()), RowIndexLookup::None)
|
||||
}
|
||||
None => {
|
||||
// Default to a sequential range index.
|
||||
(RowIndex::Range(0..num_rows), RowIndexLookup::None)
|
||||
@@ -464,6 +487,11 @@ impl<T: Clone + PartialEq> Frame<T> {
|
||||
deleted_data
|
||||
}
|
||||
|
||||
/// Returns a new `Matrix` that is the transpose of the current frame's matrix.
|
||||
pub fn transpose(&self) -> Matrix<T> {
|
||||
self.matrix.transpose()
|
||||
}
|
||||
|
||||
/// Sorts columns alphabetically by name, preserving data associations.
|
||||
pub fn sort_columns(&mut self) {
|
||||
let n = self.column_names.len();
|
||||
@@ -500,6 +528,45 @@ impl<T: Clone + PartialEq> Frame<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn frame_map(&self, f: impl Fn(&T) -> T) -> Frame<T> {
|
||||
Frame::new(
|
||||
Matrix::from_vec(
|
||||
self.matrix.data().iter().map(f).collect(),
|
||||
self.matrix.rows(),
|
||||
self.matrix.cols(),
|
||||
),
|
||||
self.column_names.clone(),
|
||||
Some(self.index.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn frame_zip(&self, other: &Frame<T>, f: impl Fn(&T, &T) -> T) -> Frame<T> {
|
||||
if self.rows() != other.rows() || self.cols() != other.cols() {
|
||||
panic!(
|
||||
"Frame::frame_zip: incompatible dimensions (self: {}x{}, other: {}x{})",
|
||||
self.rows(),
|
||||
self.cols(),
|
||||
other.rows(),
|
||||
other.cols()
|
||||
);
|
||||
}
|
||||
|
||||
Frame::new(
|
||||
Matrix::from_vec(
|
||||
self.matrix
|
||||
.data()
|
||||
.iter()
|
||||
.zip(other.matrix.data())
|
||||
.map(|(a, b)| f(a, b))
|
||||
.collect(),
|
||||
self.rows(),
|
||||
self.cols(),
|
||||
),
|
||||
self.column_names.clone(),
|
||||
Some(self.index.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
// Internal helpers
|
||||
|
||||
/// Rebuilds the column lookup map to match the current `column_names` ordering.
|
||||
@@ -781,14 +848,13 @@ impl<T: Clone + PartialEq> IndexMut<&str> for Frame<T> {
|
||||
/// Panics if column labels or row indices differ between operands.
|
||||
macro_rules! impl_elementwise_frame_op {
|
||||
($OpTrait:ident, $method:ident) => {
|
||||
// &Frame<T> $OpTrait &Frame<T>
|
||||
impl<'a, 'b, T> std::ops::$OpTrait<&'b Frame<T>> for &'a Frame<T>
|
||||
where
|
||||
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
type Output = Frame<T>;
|
||||
|
||||
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
|
||||
// Verify matching schema
|
||||
if self.column_names != rhs.column_names {
|
||||
panic!(
|
||||
"Element-wise {}: column names do not match. Left: {:?}, Right: {:?}",
|
||||
@@ -805,21 +871,47 @@ macro_rules! impl_elementwise_frame_op {
|
||||
rhs.index
|
||||
);
|
||||
}
|
||||
|
||||
// Apply the matrix operation
|
||||
let result_matrix = (&self.matrix).$method(&rhs.matrix);
|
||||
|
||||
// Determine index for the result
|
||||
let new_index = match self.index {
|
||||
RowIndex::Range(_) => None,
|
||||
_ => Some(self.index.clone()),
|
||||
};
|
||||
|
||||
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
||||
}
|
||||
}
|
||||
// Frame<T> $OpTrait &Frame<T>
|
||||
impl<'b, T> std::ops::$OpTrait<&'b Frame<T>> for Frame<T>
|
||||
where
|
||||
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
type Output = Frame<T>;
|
||||
fn $method(self, rhs: &'b Frame<T>) -> Frame<T> {
|
||||
(&self).$method(rhs)
|
||||
}
|
||||
}
|
||||
// &Frame<T> $OpTrait Frame<T>
|
||||
impl<'a, T> std::ops::$OpTrait<Frame<T>> for &'a Frame<T>
|
||||
where
|
||||
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
type Output = Frame<T>;
|
||||
fn $method(self, rhs: Frame<T>) -> Frame<T> {
|
||||
self.$method(&rhs)
|
||||
}
|
||||
}
|
||||
// Frame<T> $OpTrait Frame<T>
|
||||
impl<T> std::ops::$OpTrait<Frame<T>> for Frame<T>
|
||||
where
|
||||
T: Clone + PartialEq + std::ops::$OpTrait<Output = T>,
|
||||
{
|
||||
type Output = Frame<T>;
|
||||
fn $method(self, rhs: Frame<T>) -> Frame<T> {
|
||||
(&self).$method(&rhs)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_elementwise_frame_op!(Add, add);
|
||||
impl_elementwise_frame_op!(Sub, sub);
|
||||
impl_elementwise_frame_op!(Mul, mul);
|
||||
@@ -830,11 +922,10 @@ impl_elementwise_frame_op!(Div, div);
|
||||
/// Panics if column labels or row indices differ between operands.
|
||||
macro_rules! impl_bitwise_frame_op {
|
||||
($OpTrait:ident, $method:ident) => {
|
||||
// &Frame<bool> $OpTrait &Frame<bool>
|
||||
impl<'a, 'b> std::ops::$OpTrait<&'b Frame<bool>> for &'a Frame<bool> {
|
||||
type Output = Frame<bool>;
|
||||
|
||||
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
|
||||
// Verify matching schema
|
||||
if self.column_names != rhs.column_names {
|
||||
panic!(
|
||||
"Bitwise {}: column names do not match. Left: {:?}, Right: {:?}",
|
||||
@@ -851,25 +942,43 @@ macro_rules! impl_bitwise_frame_op {
|
||||
rhs.index
|
||||
);
|
||||
}
|
||||
|
||||
// Apply the matrix operation
|
||||
let result_matrix = (&self.matrix).$method(&rhs.matrix);
|
||||
|
||||
// Determine index for the result
|
||||
let new_index = match self.index {
|
||||
RowIndex::Range(_) => None,
|
||||
_ => Some(self.index.clone()),
|
||||
};
|
||||
|
||||
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
||||
}
|
||||
}
|
||||
// Frame<bool> $OpTrait &Frame<bool>
|
||||
impl<'b> std::ops::$OpTrait<&'b Frame<bool>> for Frame<bool> {
|
||||
type Output = Frame<bool>;
|
||||
fn $method(self, rhs: &'b Frame<bool>) -> Frame<bool> {
|
||||
(&self).$method(rhs)
|
||||
}
|
||||
}
|
||||
// &Frame<bool> $OpTrait Frame<bool>
|
||||
impl<'a> std::ops::$OpTrait<Frame<bool>> for &'a Frame<bool> {
|
||||
type Output = Frame<bool>;
|
||||
fn $method(self, rhs: Frame<bool>) -> Frame<bool> {
|
||||
self.$method(&rhs)
|
||||
}
|
||||
}
|
||||
// Frame<bool> $OpTrait Frame<bool>
|
||||
impl std::ops::$OpTrait<Frame<bool>> for Frame<bool> {
|
||||
type Output = Frame<bool>;
|
||||
fn $method(self, rhs: Frame<bool>) -> Frame<bool> {
|
||||
(&self).$method(&rhs)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_bitwise_frame_op!(BitAnd, bitand);
|
||||
impl_bitwise_frame_op!(BitOr, bitor);
|
||||
impl_bitwise_frame_op!(BitXor, bitxor);
|
||||
|
||||
/* ---------- Logical NOT ---------- */
|
||||
/// Implements logical NOT (`!`) for `Frame<bool>`, consuming the frame.
|
||||
impl Not for Frame<bool> {
|
||||
type Output = Frame<bool>;
|
||||
@@ -888,12 +997,30 @@ impl Not for Frame<bool> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements logical NOT (`!`) for `&Frame<bool>`, borrowing the frame.
|
||||
impl Not for &Frame<bool> {
|
||||
type Output = Frame<bool>;
|
||||
|
||||
fn not(self) -> Frame<bool> {
|
||||
// Apply NOT to the underlying matrix
|
||||
let result_matrix = !&self.matrix;
|
||||
|
||||
// Determine index for the result
|
||||
let new_index = match self.index {
|
||||
RowIndex::Range(_) => None,
|
||||
_ => Some(self.index.clone()),
|
||||
};
|
||||
|
||||
Frame::new(result_matrix, self.column_names.clone(), new_index)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
// Assume Matrix is available from crate::matrix or similar
|
||||
use crate::matrix::Matrix;
|
||||
use crate::matrix::{BoolOps, Matrix};
|
||||
use chrono::NaiveDate;
|
||||
// HashMap needed for direct inspection in tests if required
|
||||
use std::collections::HashMap;
|
||||
@@ -1057,10 +1184,10 @@ mod tests {
|
||||
Frame::new(matrix, vec!["X", "Y"], Some(index));
|
||||
}
|
||||
#[test]
|
||||
#[should_panic(expected = "Cannot explicitly provide a Range index")]
|
||||
fn frame_new_panic_explicit_range() {
|
||||
let matrix = create_test_matrix_f64();
|
||||
let index = RowIndex::Range(0..3); // User cannot provide Range directly
|
||||
#[should_panic(expected = "Frame::new: Range index length (4) mismatch matrix rows (3)")]
|
||||
fn frame_new_panic_invalid_explicit_range_index() {
|
||||
let matrix = create_test_matrix_f64(); // 3 rows
|
||||
let index = RowIndex::Range(0..4); // Range 0..4 but only 3 rows
|
||||
Frame::new(matrix, vec!["A", "B"], Some(index));
|
||||
}
|
||||
|
||||
@@ -1349,7 +1476,7 @@ mod tests {
|
||||
fn test_row_view_name_panic() {
|
||||
let frame = create_test_frame_f64();
|
||||
let row_view = frame.get_row(0);
|
||||
let _ = row_view["C"]; // Access non-existent column name
|
||||
let _ = row_view["C"]; // Access non-existent column Z
|
||||
}
|
||||
#[test]
|
||||
#[should_panic(expected = "column index 3 out of bounds")] // Check specific message
|
||||
@@ -1581,6 +1708,45 @@ mod tests {
|
||||
assert_eq!(frame1.columns(), &["Z"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_map() {
|
||||
let frame = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6]
|
||||
let mapped_frame = frame.frame_map(|x| x * 2.0); // Multiply each value by 2.0
|
||||
assert_eq!(mapped_frame.columns(), frame.columns());
|
||||
assert_eq!(mapped_frame.index(), frame.index());
|
||||
assert!((mapped_frame["A"][0] - 2.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["A"][1] - 4.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["A"][2] - 6.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["B"][0] - 8.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["B"][1] - 10.0).abs() < FLOAT_TOLERANCE);
|
||||
assert!((mapped_frame["B"][2] - 12.0).abs() < FLOAT_TOLERANCE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_zip() {
|
||||
let f1 = create_test_frame_f64(); // A=[1,2,3], B=[4,5,6]
|
||||
let f2 = create_test_frame_f64_alt(); // A=[0.1,0.2,0.3], B=[0.4,0.5,0.6]
|
||||
let zipped_frame = f1.frame_zip(&f2, |x, y| x + y); // Element-wise addition
|
||||
assert_eq!(zipped_frame.columns(), f1.columns());
|
||||
assert_eq!(zipped_frame.index(), f1.index());
|
||||
assert!((zipped_frame["A"][0] - 1.1).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["A"][1] - 2.2).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["A"][2] - 3.3).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["B"][0] - 4.4).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["B"][1] - 5.5).abs() < FLOAT_TOLERANCE);
|
||||
assert!((zipped_frame["B"][2] - 6.6).abs() < FLOAT_TOLERANCE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Frame::frame_zip: incompatible dimensions (self: 3x1, other: 3x2)")]
|
||||
fn test_frame_zip_panic() {
|
||||
let mut f1 = create_test_frame_f64();
|
||||
let f2 = create_test_frame_f64_alt();
|
||||
f1.delete_column("B");
|
||||
|
||||
f1.frame_zip(&f2, |x, y| x + y); // Should panic due to different column counts
|
||||
}
|
||||
|
||||
// --- Element-wise Arithmetic Ops Tests ---
|
||||
#[test]
|
||||
fn test_frame_arithmetic_ops_f64() {
|
||||
@@ -1666,6 +1832,79 @@ mod tests {
|
||||
assert_eq!(frame_div["Y"], vec![10, -10]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tests_for_frame_arithmetic_ops() {
|
||||
let ops: Vec<(
|
||||
&str,
|
||||
fn(&Frame<f64>, &Frame<f64>) -> Frame<f64>,
|
||||
fn(&Frame<f64>, &Frame<f64>) -> Frame<f64>,
|
||||
)> = vec![
|
||||
("addition", |a, b| a + b, |a, b| (&*a) + (&*b)),
|
||||
("subtraction", |a, b| a - b, |a, b| (&*a) - (&*b)),
|
||||
("multiplication", |a, b| a * b, |a, b| (&*a) * (&*b)),
|
||||
("division", |a, b| a / b, |a, b| (&*a) / (&*b)),
|
||||
];
|
||||
|
||||
for (op_name, owned_op, ref_op) in ops {
|
||||
let f1 = create_test_frame_f64();
|
||||
let f2 = create_test_frame_f64_alt();
|
||||
let result_owned = owned_op(&f1, &f2);
|
||||
let expected = ref_op(&f1, &f2);
|
||||
|
||||
assert_eq!(
|
||||
result_owned.columns(),
|
||||
f1.columns(),
|
||||
"Column mismatch for {}",
|
||||
op_name
|
||||
);
|
||||
assert_eq!(
|
||||
result_owned.index(),
|
||||
f1.index(),
|
||||
"Index mismatch for {}",
|
||||
op_name
|
||||
);
|
||||
|
||||
let bool_mat = result_owned.matrix().eq_elem(expected.matrix().clone());
|
||||
assert!(bool_mat.all(), "Element-wise {} failed", op_name);
|
||||
}
|
||||
}
|
||||
|
||||
// test not , and or on frame
|
||||
#[test]
|
||||
fn tests_for_frame_bool_ops() {
|
||||
let ops: Vec<(
|
||||
&str,
|
||||
fn(&Frame<bool>, &Frame<bool>) -> Frame<bool>,
|
||||
fn(&Frame<bool>, &Frame<bool>) -> Frame<bool>,
|
||||
)> = vec![
|
||||
("and", |a, b| a & b, |a, b| (&*a) & (&*b)),
|
||||
("or", |a, b| a | b, |a, b| (&*a) | (&*b)),
|
||||
("xor", |a, b| a ^ b, |a, b| (&*a) ^ (&*b)),
|
||||
];
|
||||
for (op_name, owned_op, ref_op) in ops {
|
||||
let f1 = create_test_frame_bool();
|
||||
let f2 = create_test_frame_bool_alt();
|
||||
let result_owned = owned_op(&f1, &f2);
|
||||
let expected = ref_op(&f1, &f2);
|
||||
|
||||
assert_eq!(
|
||||
result_owned.columns(),
|
||||
f1.columns(),
|
||||
"Column mismatch for {}",
|
||||
op_name
|
||||
);
|
||||
assert_eq!(
|
||||
result_owned.index(),
|
||||
f1.index(),
|
||||
"Index mismatch for {}",
|
||||
op_name
|
||||
);
|
||||
|
||||
let bool_mat = result_owned.matrix().eq_elem(expected.matrix().clone());
|
||||
assert!(bool_mat.all(), "Element-wise {} failed", op_name);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_arithmetic_ops_date_index() {
|
||||
let dates = vec![d(2024, 1, 1), d(2024, 1, 2)];
|
||||
|
||||
@@ -1,3 +1,21 @@
|
||||
//! High-level interface for working with columnar data and row indices.
|
||||
//!
|
||||
//! The [`Frame`](crate::frame::Frame) type combines a matrix with column labels and a typed row
|
||||
//! index, similar to data frames in other data-analysis libraries.
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::frame::{Frame, RowIndex};
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! // Build a frame from two columns labelled "A" and "B".
|
||||
//! let data = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
|
||||
//! let frame = Frame::new(data, vec!["A", "B"], None);
|
||||
//!
|
||||
//! assert_eq!(frame["A"], vec![1.0, 2.0]);
|
||||
//! assert_eq!(frame.index(), &RowIndex::Range(0..2));
|
||||
//! ```
|
||||
pub mod base;
|
||||
pub mod ops;
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
//! Trait implementations that allow [`Frame`] to reuse matrix operations.
|
||||
//!
|
||||
//! These modules forward numeric and boolean aggregation methods from the
|
||||
//! underlying [`Matrix`](crate::matrix::Matrix) type so that they can be called
|
||||
//! directly on a [`Frame`].
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::frame::Frame;
|
||||
//! use rustframe::matrix::{Matrix, SeriesOps};
|
||||
//!
|
||||
//! let frame = Frame::new(Matrix::from_cols(vec![vec![1.0, 2.0]]), vec!["A"], None);
|
||||
//! assert_eq!(frame.sum_vertical(), vec![3.0]);
|
||||
//! ```
|
||||
use crate::frame::Frame;
|
||||
use crate::matrix::{Axis, BoolMatrix, BoolOps, FloatMatrix, SeriesOps};
|
||||
|
||||
@@ -21,6 +34,28 @@ impl SeriesOps for Frame<f64> {
|
||||
self.matrix().apply_axis(axis, f)
|
||||
}
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64,
|
||||
{
|
||||
self.matrix().map(f)
|
||||
}
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64,
|
||||
{
|
||||
self.matrix().zip(other.matrix(), f)
|
||||
}
|
||||
|
||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
|
||||
self.matrix().matrix_mul(other.matrix())
|
||||
}
|
||||
|
||||
fn dot(&self, other: &Self) -> FloatMatrix {
|
||||
self.matrix().dot(other.matrix())
|
||||
}
|
||||
|
||||
delegate_to_matrix!(
|
||||
sum_vertical -> Vec<f64>,
|
||||
sum_horizontal -> Vec<f64>,
|
||||
@@ -106,7 +141,7 @@ mod tests {
|
||||
let col_names = vec!["A".to_string(), "B".to_string()];
|
||||
let frame = Frame::new(
|
||||
Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]),
|
||||
col_names,
|
||||
col_names.clone(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(frame.sum_vertical(), frame.matrix().sum_vertical());
|
||||
@@ -128,6 +163,33 @@ mod tests {
|
||||
);
|
||||
assert_eq!(frame.is_nan(), frame.matrix().is_nan());
|
||||
assert_eq!(frame.apply_axis(Axis::Row, |x| x[0] + x[1]), vec![4.0, 6.0]);
|
||||
|
||||
assert_eq!(
|
||||
frame.matrix_mul(&frame),
|
||||
frame.matrix().matrix_mul(&frame.matrix())
|
||||
);
|
||||
assert_eq!(frame.dot(&frame), frame.matrix().dot(&frame.matrix()));
|
||||
|
||||
// test transpose - returns a matrix.
|
||||
let frame_transposed_mat = frame.transpose();
|
||||
let frame_mat_transposed = frame.matrix().transpose();
|
||||
assert_eq!(frame_transposed_mat, frame_mat_transposed);
|
||||
assert_eq!(frame.matrix(), &frame.matrix().transpose().transpose());
|
||||
|
||||
// test map
|
||||
let mapped_frame = frame.map(|x| x * 2.0);
|
||||
let expected_matrix = frame.matrix().map(|x| x * 2.0);
|
||||
assert_eq!(mapped_frame, expected_matrix);
|
||||
|
||||
// test zip
|
||||
let other_frame = Frame::new(
|
||||
Matrix::from_cols(vec![vec![5.0, 6.0], vec![7.0, 8.0]]),
|
||||
col_names.clone(),
|
||||
None,
|
||||
);
|
||||
let zipped_frame = frame.zip(&other_frame, |x, y| x + y);
|
||||
let expected_zipped_matrix = frame.matrix().zip(other_frame.matrix(), |x, y| x + y);
|
||||
assert_eq!(zipped_frame, expected_zipped_matrix);
|
||||
}
|
||||
#[test]
|
||||
|
||||
|
||||
@@ -8,3 +8,9 @@ pub mod frame;
|
||||
|
||||
/// Documentation for the [`crate::utils`] module.
|
||||
pub mod utils;
|
||||
|
||||
/// Documentation for the [`crate::compute`] module.
|
||||
pub mod compute;
|
||||
|
||||
/// Documentation for the [`crate::random`] module.
|
||||
pub mod random;
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
//! Logical reductions for boolean matrices.
|
||||
//!
|
||||
//! The [`BoolOps`] trait mirrors common boolean aggregations such as `any` and
|
||||
//! `all` over rows or columns of a [`BoolMatrix`].
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::matrix::{BoolMatrix, BoolOps};
|
||||
//!
|
||||
//! let m = BoolMatrix::from_vec(vec![true, false], 2, 1);
|
||||
//! assert!(m.any());
|
||||
//! ```
|
||||
use crate::matrix::{Axis, BoolMatrix};
|
||||
|
||||
/// Boolean operations on `Matrix<bool>`
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,22 @@
|
||||
//! Core matrix types and operations.
|
||||
//!
|
||||
//! The [`Matrix`](crate::matrix::Matrix) struct provides a simple column‑major 2D array with a
|
||||
//! suite of numeric helpers. Additional traits like [`SeriesOps`](crate::matrix::SeriesOps) and
|
||||
//! [`BoolOps`](crate::matrix::BoolOps) extend functionality for common statistics and logical reductions.
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::matrix::Matrix;
|
||||
//!
|
||||
//! let m = Matrix::from_cols(vec![vec![1, 2], vec![3, 4]]);
|
||||
//! assert_eq!(m.shape(), (2, 2));
|
||||
//! assert_eq!(m[(0,1)], 3);
|
||||
//! ```
|
||||
pub mod boolops;
|
||||
pub mod mat;
|
||||
pub mod seriesops;
|
||||
pub mod boolops;
|
||||
|
||||
pub use boolops::*;
|
||||
pub use mat::*;
|
||||
pub use seriesops::*;
|
||||
pub use boolops::*;
|
||||
@@ -1,3 +1,14 @@
|
||||
//! Numeric reductions and transformations over matrix axes.
|
||||
//!
|
||||
//! [`SeriesOps`] provides methods like [`SeriesOps::sum_vertical`] or
|
||||
//! [`SeriesOps::map`] that operate on [`FloatMatrix`] values.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::matrix::{Matrix, SeriesOps};
|
||||
//!
|
||||
//! let m = Matrix::from_cols(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
|
||||
//! assert_eq!(m.sum_horizontal(), vec![4.0, 6.0]);
|
||||
//! ```
|
||||
use crate::matrix::{Axis, BoolMatrix, FloatMatrix};
|
||||
|
||||
/// "Series-like" helpers that work along a single axis.
|
||||
@@ -12,6 +23,17 @@ pub trait SeriesOps {
|
||||
where
|
||||
F: FnMut(&[f64]) -> U;
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64;
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64;
|
||||
|
||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix;
|
||||
fn dot(&self, other: &Self) -> FloatMatrix;
|
||||
|
||||
fn sum_vertical(&self) -> Vec<f64>;
|
||||
fn sum_horizontal(&self) -> Vec<f64>;
|
||||
|
||||
@@ -139,23 +161,88 @@ impl SeriesOps for FloatMatrix {
|
||||
let data = self.data().iter().map(|v| v.is_nan()).collect();
|
||||
BoolMatrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
}
|
||||
|
||||
fn matrix_mul(&self, other: &Self) -> FloatMatrix {
|
||||
let (m, n) = (self.rows(), self.cols());
|
||||
let (n2, p) = (other.rows(), other.cols());
|
||||
assert_eq!(
|
||||
n, n2,
|
||||
"Cannot multiply: left is {}x{}, right is {}x{}",
|
||||
m, n, n2, p
|
||||
);
|
||||
|
||||
// Column-major addressing: element (row i, col j) lives at j * m + i
|
||||
let mut data = vec![0.0; m * p];
|
||||
for i in 0..m {
|
||||
for j in 0..p {
|
||||
let mut sum = 0.0;
|
||||
for k in 0..n {
|
||||
sum += self[(i, k)] * other[(k, j)];
|
||||
}
|
||||
data[j * m + i] = sum; // <-- fixed index
|
||||
}
|
||||
}
|
||||
FloatMatrix::from_vec(data, m, p)
|
||||
}
|
||||
fn dot(&self, other: &Self) -> FloatMatrix {
|
||||
self.matrix_mul(other)
|
||||
}
|
||||
|
||||
fn map<F>(&self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64) -> f64,
|
||||
{
|
||||
let data = self.data().iter().map(|&v| f(v)).collect::<Vec<_>>();
|
||||
FloatMatrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
|
||||
fn zip<F>(&self, other: &Self, f: F) -> FloatMatrix
|
||||
where
|
||||
F: Fn(f64, f64) -> f64,
|
||||
{
|
||||
assert!(
|
||||
self.rows() == other.rows() && self.cols() == other.cols(),
|
||||
"Matrix dimensions mismatch: left is {}x{}, right is {}x{}",
|
||||
self.rows(),
|
||||
self.cols(),
|
||||
other.rows(),
|
||||
other.cols()
|
||||
);
|
||||
|
||||
let data = self
|
||||
.data()
|
||||
.iter()
|
||||
.zip(other.data().iter())
|
||||
.map(|(&a, &b)| f(a, b))
|
||||
.collect();
|
||||
crate::matrix::Matrix::from_vec(data, self.rows(), self.cols())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
// Helper function to create a FloatMatrix for SeriesOps testing
|
||||
fn create_float_test_matrix() -> FloatMatrix {
|
||||
// 3x3 matrix (column-major) with some NaNs
|
||||
// 1.0 4.0 7.0
|
||||
// 2.0 NaN 8.0
|
||||
// 3.0 6.0 NaN
|
||||
// 3x3 column-major matrix containing a few NaN values
|
||||
let data = vec![1.0, 2.0, 3.0, 4.0, f64::NAN, 6.0, 7.0, 8.0, f64::NAN];
|
||||
FloatMatrix::from_vec(data, 3, 3)
|
||||
}
|
||||
|
||||
fn create_float_test_matrix_4x4() -> FloatMatrix {
|
||||
// 4x4 column-major matrix with NaNs inserted at positions where index % 5 == 0
|
||||
// first make array with 16 elements
|
||||
FloatMatrix::from_vec(
|
||||
(0..16)
|
||||
.map(|i| if i % 5 == 0 { f64::NAN } else { i as f64 })
|
||||
.collect(),
|
||||
4,
|
||||
4,
|
||||
)
|
||||
}
|
||||
|
||||
// --- Tests for SeriesOps (FloatMatrix) ---
|
||||
|
||||
#[test]
|
||||
@@ -256,6 +343,90 @@ mod tests {
|
||||
assert_eq!(matrix.is_nan(), expected_matrix);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_matrix_mul() {
|
||||
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
|
||||
// result should be: 23, 34, 31, 46
|
||||
let expected = FloatMatrix::from_vec(vec![23.0, 34.0, 31.0, 46.0], 2, 2);
|
||||
assert_eq!(a.matrix_mul(&b), expected);
|
||||
|
||||
assert_eq!(a.dot(&b), a.matrix_mul(&b)); // dot should be the same as matrix_mul for FloatMatrix
|
||||
}
|
||||
#[test]
|
||||
fn test_series_ops_matrix_mul_with_nans() {
|
||||
let a = create_float_test_matrix(); // 3x3 matrix with some NaNs
|
||||
let b = create_float_test_matrix(); // 3x3 matrix with some NaNs
|
||||
|
||||
let mut result_vec = Vec::new();
|
||||
result_vec.push(30.0);
|
||||
for _ in 1..9 {
|
||||
result_vec.push(f64::NAN);
|
||||
}
|
||||
let expected = FloatMatrix::from_vec(result_vec, 3, 3);
|
||||
|
||||
let result = a.matrix_mul(&b);
|
||||
|
||||
assert_eq!(result.is_nan(), expected.is_nan());
|
||||
assert_eq!(
|
||||
result.count_nan_horizontal(),
|
||||
expected.count_nan_horizontal()
|
||||
);
|
||||
assert_eq!(result.count_nan_vertical(), expected.count_nan_vertical());
|
||||
assert_eq!(result[(0, 0)], expected[(0, 0)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Cannot multiply: left is 3x3, right is 4x4")]
|
||||
fn test_series_ops_matrix_mul_errors() {
|
||||
let a = create_float_test_matrix();
|
||||
let b = create_float_test_matrix_4x4();
|
||||
|
||||
a.dot(&b); // This should panic due to dimension mismatch
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_map() {
|
||||
let matrix = create_float_test_matrix();
|
||||
// Map function to double each value
|
||||
let mapped_matrix = matrix.map(|x| x * 2.0);
|
||||
// Expected data after mapping
|
||||
let expected_data = vec![2.0, 4.0, 6.0, 8.0, f64::NAN, 12.0, 14.0, 16.0, f64::NAN];
|
||||
let expected_matrix = FloatMatrix::from_vec(expected_data, 3, 3);
|
||||
// assert_eq!(mapped_matrix, expected_matrix);
|
||||
for i in 0..mapped_matrix.data().len() {
|
||||
// if not nan, check equality
|
||||
if !mapped_matrix.data()[i].is_nan() {
|
||||
assert_eq!(mapped_matrix.data()[i], expected_matrix.data()[i]);
|
||||
} else {
|
||||
assert!(mapped_matrix.data()[i].is_nan());
|
||||
assert!(expected_matrix.data()[i].is_nan());
|
||||
}
|
||||
}
|
||||
assert_eq!(mapped_matrix.rows(), expected_matrix.rows());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_series_ops_zip() {
|
||||
let a = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
|
||||
let b = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0], 2, 2); // 2x2 matrix
|
||||
// Zip function to add corresponding elements
|
||||
let zipped_matrix = a.zip(&b, |x, y| x + y);
|
||||
// Expected data after zipping
|
||||
let expected_data = vec![6.0, 8.0, 10.0, 12.0];
|
||||
let expected_matrix = FloatMatrix::from_vec(expected_data, 2, 2);
|
||||
assert_eq!(zipped_matrix, expected_matrix);
|
||||
}
|
||||
|
||||
/// `zip` must panic when the operand shapes differ.
#[test]
#[should_panic(expected = "Matrix dimensions mismatch: left is 2x2, right is 3x2")]
fn test_series_ops_zip_panic() {
    let small = FloatMatrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2); // 2x2 matrix
    let tall = FloatMatrix::from_vec(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 3, 2); // 3x2 matrix
    // Shape mismatch triggers the panic asserted above.
    small.zip(&tall, |x, y| x + y);
}
|
||||
|
||||
// --- Edge Cases for SeriesOps ---
|
||||
|
||||
#[test]
|
||||
|
||||
237
src/random/crypto.rs
Normal file
237
src/random/crypto.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
//! Cryptographically secure random number generator.
|
||||
//!
|
||||
//! On Unix systems this reads from `/dev/urandom`; on Windows it uses the
|
||||
//! system's preferred CNG provider.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::random::{crypto_rng, Rng};
|
||||
//! let mut rng = crypto_rng();
|
||||
//! let _v = rng.next_u64();
|
||||
//! ```
|
||||
#[cfg(unix)]
|
||||
use std::{fs::File, io::Read};
|
||||
|
||||
use crate::random::Rng;
|
||||
|
||||
#[cfg(unix)]
pub struct CryptoRng {
    // Persistent handle to the system entropy device so each draw avoids
    // re-opening `/dev/urandom`.
    file: File,
}

#[cfg(unix)]
impl CryptoRng {
    /// Open `/dev/urandom`.
    ///
    /// # Panics
    ///
    /// Panics if `/dev/urandom` cannot be opened (e.g. in a sandbox with no
    /// access to the device).
    pub fn new() -> Self {
        let file = File::open("/dev/urandom").expect("failed to open /dev/urandom");
        Self { file }
    }
}

// A public type with a no-argument `new` should also implement `Default`
// (clippy::new_without_default), so the generator works with
// `Default::default()`-based construction.
#[cfg(unix)]
impl Default for CryptoRng {
    /// Equivalent to [`CryptoRng::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Rng for CryptoRng {
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
let mut buf = [0u8; 8];
|
||||
self.file
|
||||
.read_exact(&mut buf)
|
||||
.expect("failed reading from /dev/urandom");
|
||||
u64::from_ne_bytes(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// On Windows no persistent handle is required: entropy is requested from the
// OS on demand via `BCryptGenRandom` (see `win_fill` below).
#[cfg(windows)]
pub struct CryptoRng;

#[cfg(windows)]
impl CryptoRng {
    /// No handle is needed on Windows.
    pub fn new() -> Self {
        Self
    }
}

#[cfg(windows)]
impl Rng for CryptoRng {
    /// Fill eight bytes from the system CSPRNG and reinterpret them as a
    /// native-endian `u64`.
    ///
    /// Panics if `BCryptGenRandom` reports failure.
    fn next_u64(&mut self) -> u64 {
        let mut buf = [0u8; 8];
        win_fill(&mut buf).expect("BCryptGenRandom failed");
        u64::from_ne_bytes(buf)
    }
}
|
||||
|
||||
/// Fill `buf` with cryptographically secure random bytes using CNG.
///
/// * `BCryptGenRandom(NULL, buf, len, BCRYPT_USE_SYSTEM_PREFERRED_RNG)`
/// asks the OS for its system-preferred DRBG (CTR_DRBG on modern
/// Windows).
///
/// Returns `Err(())` when the call reports a failing `NTSTATUS`.
#[cfg(windows)]
fn win_fill(buf: &mut [u8]) -> Result<(), ()> {
    use core::ffi::c_void;

    // Minimal hand-rolled FFI surface so no external crate is needed.
    type BcryptAlgHandle = *mut c_void;
    type NTSTATUS = i32;

    const BCRYPT_USE_SYSTEM_PREFERRED_RNG: u32 = 0x0000_0002;

    #[link(name = "bcrypt")]
    extern "system" {
        fn BCryptGenRandom(
            hAlgorithm: BcryptAlgHandle,
            pbBuffer: *mut u8,
            cbBuffer: u32,
            dwFlags: u32,
        ) -> NTSTATUS;
    }

    // NT_SUCCESS(status) == status >= 0
    // SAFETY: `buf.as_mut_ptr()` is valid for writes of `buf.len()` bytes for
    // the duration of the call; a null algorithm handle is the documented
    // usage when BCRYPT_USE_SYSTEM_PREFERRED_RNG is set.
    let status = unsafe {
        BCryptGenRandom(
            core::ptr::null_mut(),
            buf.as_mut_ptr(),
            buf.len() as u32,
            BCRYPT_USE_SYSTEM_PREFERRED_RNG,
        )
    };

    if status >= 0 {
        Ok(())
    } else {
        Err(())
    }
}
|
||||
|
||||
/// Convenience constructor for [`CryptoRng`].
///
/// # Panics
///
/// On Unix this panics if `/dev/urandom` cannot be opened.
pub fn crypto_rng() -> CryptoRng {
    CryptoRng::new()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::Rng;
    use std::collections::HashSet;

    // Six consecutive draws must not all be identical.
    #[test]
    fn test_crypto_rng_nonzero() {
        let mut rng = CryptoRng::new();
        let mut all_same = true;
        let mut prev = rng.next_u64();
        for _ in 0..5 {
            let val = rng.next_u64();
            if val != prev {
                all_same = false;
            }
            prev = val;
        }
        assert!(!all_same, "CryptoRng produced identical values");
    }

    // 100 draws should be (nearly) all distinct u64 values.
    #[test]
    fn test_crypto_rng_variation_large() {
        let mut rng = CryptoRng::new();
        let mut values = HashSet::new();
        for _ in 0..100 {
            values.insert(rng.next_u64());
        }
        assert!(values.len() > 90, "CryptoRng output not varied enough");
    }

    // Bucketed range samples should be roughly uniform (loose +/-50 band).
    #[test]
    fn test_crypto_rng_random_range_uniform() {
        let mut rng = CryptoRng::new();
        let mut counts = [0usize; 10];
        for _ in 0..1000 {
            let v = rng.random_range(0..10usize);
            counts[v] += 1;
        }
        for &c in &counts {
            // "Crypto RNG counts far from uniform: {c}"
            assert!((c as isize - 100).abs() < 50);
        }
    }

    // Sample mean/variance of `normal` should approximate the parameters.
    #[test]
    fn test_crypto_normal_distribution() {
        let mut rng = CryptoRng::new();
        let mean = 0.0;
        let sd = 1.0;
        let n = 2000;
        let mut sum = 0.0;
        let mut sum_sq = 0.0;
        for _ in 0..n {
            let val = rng.normal(mean, sd);
            sum += val;
            sum_sq += val * val;
        }
        let sample_mean = sum / n as f64;
        let sample_var = sum_sq / n as f64 - sample_mean * sample_mean;
        assert!(sample_mean.abs() < 0.1);
        assert!((sample_var - 1.0).abs() < 0.2);
    }

    // Two independent generators should disagree on the first draw;
    // a collision has probability 2^-64 (effectively impossible).
    #[test]
    fn test_two_instances_different_values() {
        let mut a = CryptoRng::new();
        let mut b = CryptoRng::new();
        let va = a.next_u64();
        let vb = b.next_u64();
        assert_ne!(va, vb);
    }

    // The free-function helper must hand back a usable generator.
    #[test]
    fn test_crypto_rng_helper_function() {
        let mut rng = crypto_rng();
        let _ = rng.next_u64();
    }

    // With sd == 0 the normal sampler must return the mean exactly.
    #[test]
    fn test_crypto_normal_zero_sd() {
        let mut rng = CryptoRng::new();
        for _ in 0..5 {
            let v = rng.normal(10.0, 0.0);
            assert_eq!(v, 10.0);
        }
    }

    // Shuffling an empty slice is a no-op and must not panic.
    #[test]
    fn test_crypto_shuffle_empty_slice() {
        use crate::random::SliceRandom;
        let mut rng = CryptoRng::new();
        let mut arr: [u8; 0] = [];
        arr.shuffle(&mut rng);
        assert!(arr.is_empty());
    }

    // Chi-square goodness-of-fit over 10 buckets (9 degrees of freedom);
    // 40.0 is a very generous threshold to keep flakiness negligible.
    #[test]
    fn test_crypto_chi_square_uniform() {
        let mut rng = CryptoRng::new();
        let mut counts = [0usize; 10];
        let samples = 10000;
        for _ in 0..samples {
            let v = rng.random_range(0..10usize);
            counts[v] += 1;
        }
        let expected = samples as f64 / 10.0;
        let chi2: f64 = counts
            .iter()
            .map(|&c| {
                let diff = c as f64 - expected;
                diff * diff / expected
            })
            .sum();
        assert!(chi2 < 40.0, "chi-square statistic too high: {chi2}");
    }

    // Monobit test: about half of all generated bits should be ones.
    #[test]
    fn test_crypto_monobit() {
        let mut rng = CryptoRng::new();
        let mut ones = 0usize;
        let samples = 1000;
        for _ in 0..samples {
            ones += rng.next_u64().count_ones() as usize;
        }
        let total_bits = samples * 64;
        let ratio = ones as f64 / total_bits as f64;
        // "bit ratio far from 0.5: {ratio}"
        assert!((ratio - 0.5).abs() < 0.02);
    }
}
|
||||
29
src/random/mod.rs
Normal file
29
src/random/mod.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
//! Random number generation utilities.
|
||||
//!
|
||||
//! Provides both a simple pseudo-random generator [`Prng`](crate::random::Prng) and a
|
||||
//! cryptographically secure alternative [`CryptoRng`](crate::random::CryptoRng). The
|
||||
//! [`SliceRandom`](crate::random::SliceRandom) trait offers shuffling of slices using any RNG
|
||||
//! implementing [`Rng`](crate::random::Rng).
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::random::{rng, SliceRandom};
|
||||
//!
|
||||
//! let mut rng = rng();
|
||||
//! let mut data = [1, 2, 3, 4];
|
||||
//! data.shuffle(&mut rng);
|
||||
//! assert_eq!(data.len(), 4);
|
||||
//! ```
|
||||
pub mod crypto;
pub mod prng;
pub mod random_core;
pub mod seq;

// Re-export the main entry points so callers can write
// `use rustframe::random::{rng, Rng}` without naming submodules.
pub use crypto::{crypto_rng, CryptoRng};
pub use prng::{rng, Prng};
pub use random_core::{RangeSample, Rng};
pub use seq::SliceRandom;

/// One-stop import for the commonly used random items:
/// `use rustframe::random::prelude::*;`
pub mod prelude {
    pub use super::seq::SliceRandom;
    pub use super::{crypto_rng, rng, CryptoRng, Prng, RangeSample, Rng};
}
|
||||
235
src/random/prng.rs
Normal file
235
src/random/prng.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
//! A tiny XorShift64-based pseudo random number generator.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::random::{rng, Rng};
|
||||
//! let mut rng = rng();
|
||||
//! let _value = rng.next_u64();
|
||||
//! ```
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::random::Rng;
|
||||
|
||||
/// Simple XorShift64-based pseudo random number generator.
#[derive(Clone)]
pub struct Prng {
    // XorShift64 state; must never be zero (zero is a fixed point of the
    // update and would yield zeros forever). `Prng::new` guarantees this.
    state: u64,
}

impl Prng {
    /// Create a new generator from the given seed.
    ///
    /// XorShift64 maps a zero state to itself, so a seed of `0` would
    /// produce an endless stream of zeros. A zero seed is therefore
    /// remapped to a fixed non-zero constant.
    pub fn new(seed: u64) -> Self {
        // 2^64 / golden ratio — an arbitrary, well-mixed non-zero filler.
        let state = if seed == 0 { 0x9E37_79B9_7F4A_7C15 } else { seed };
        Self { state }
    }

    /// Create a generator seeded from the current time.
    ///
    /// Not cryptographically secure: the seed is just the nanosecond
    /// wall-clock timestamp.
    pub fn from_entropy() -> Self {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_nanos() as u64;
        Self::new(nanos)
    }
}
|
||||
|
||||
impl Rng for Prng {
    /// Advance the XorShift64 state and return it.
    ///
    /// Shift triplet (13, 7, 17) is Marsaglia's classic full-period choice:
    /// the sequence visits every non-zero 64-bit value before repeating.
    /// NOTE(review): a zero state is a fixed point and would yield zeros
    /// forever — `Prng::new` is expected to guarantee a non-zero state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
}
|
||||
|
||||
/// Convenience constructor using system entropy.
///
/// The seed is the current wall-clock time, so this is suitable for
/// simulations and tests but NOT for cryptographic purposes — use
/// `crypto_rng` for those.
pub fn rng() -> Prng {
    Prng::from_entropy()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::Rng;

    // Equal seeds must produce identical streams.
    #[test]
    fn test_prng_determinism() {
        let mut a = Prng::new(42);
        let mut b = Prng::new(42);
        for _ in 0..5 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    // f64 samples stay within the half-open range.
    #[test]
    fn test_random_range_f64() {
        let mut rng = Prng::new(1);
        for _ in 0..10 {
            let v = rng.random_range(-1.0..1.0);
            assert!(v >= -1.0 && v < 1.0);
        }
    }

    // usize samples stay within the half-open range.
    #[test]
    fn test_random_range_usize() {
        let mut rng = Prng::new(9);
        for _ in 0..100 {
            let v = rng.random_range(10..20);
            assert!(v >= 10 && v < 20);
        }
    }

    // gen_bool should be roughly balanced over 1000 draws.
    #[test]
    fn test_gen_bool_balance() {
        let mut rng = Prng::new(123);
        let mut trues = 0;
        for _ in 0..1000 {
            if rng.gen_bool() {
                trues += 1;
            }
        }
        let ratio = trues as f64 / 1000.0;
        assert!(ratio > 0.4 && ratio < 0.6);
    }

    // Sample moments of `normal` should approximate mean/variance.
    #[test]
    fn test_normal_distribution() {
        let mut rng = Prng::new(7);
        let mut sum = 0.0;
        let mut sum_sq = 0.0;
        let mean = 5.0;
        let sd = 2.0;
        let n = 5000;
        for _ in 0..n {
            let val = rng.normal(mean, sd);
            sum += val;
            sum_sq += val * val;
        }
        let sample_mean = sum / n as f64;
        let sample_var = sum_sq / n as f64 - sample_mean * sample_mean;
        assert!((sample_mean - mean).abs() < 0.1);
        assert!((sample_var - sd * sd).abs() < 0.2 * sd * sd);
    }

    // Time-based seeding should differ across closely spaced constructions.
    #[test]
    fn test_prng_from_entropy_unique() {
        use std::{collections::HashSet, thread, time::Duration};
        let mut seen = HashSet::new();
        for _ in 0..5 {
            let mut rng = Prng::from_entropy();
            seen.insert(rng.next_u64());
            // Ensure the nanosecond clock advances between constructions.
            thread::sleep(Duration::from_micros(1));
        }
        assert!(seen.len() > 1, "Entropy seeds produced identical outputs");
    }

    // Bucket counts over 10k draws should sit near 1000 each.
    #[test]
    fn test_prng_uniform_distribution() {
        let mut rng = Prng::new(12345);
        let mut counts = [0usize; 10];
        for _ in 0..10000 {
            let v = rng.random_range(0..10usize);
            counts[v] += 1;
        }
        for &c in &counts {
            // "PRNG counts far from uniform: {c}"
            assert!((c as isize - 1000).abs() < 150);
        }
    }

    // Different seeds must diverge immediately.
    #[test]
    fn test_prng_different_seeds_different_output() {
        let mut a = Prng::new(1);
        let mut b = Prng::new(2);
        let va = a.next_u64();
        let vb = b.next_u64();
        assert_ne!(va, vb);
    }

    // Both boolean outcomes should appear within 100 draws.
    #[test]
    fn test_prng_gen_bool_varies() {
        let mut rng = Prng::new(99);
        let mut seen_true = false;
        let mut seen_false = false;
        for _ in 0..100 {
            if rng.gen_bool() {
                seen_true = true;
            } else {
                seen_false = true;
            }
        }
        assert!(seen_true && seen_false);
    }

    // A one-element usize range has exactly one possible sample.
    #[test]
    fn test_random_range_single_usize() {
        let mut rng = Prng::new(42);
        for _ in 0..10 {
            let v = rng.random_range(5..6);
            assert_eq!(v, 5);
        }
    }

    // A very narrow f64 range still yields in-range samples.
    #[test]
    fn test_random_range_single_f64() {
        let mut rng = Prng::new(42);
        for _ in 0..10 {
            let v = rng.random_range(1.234..1.235);
            assert!(v >= 1.234 && v < 1.235);
        }
    }

    // With sd == 0 the normal sampler must return the mean exactly.
    #[test]
    fn test_prng_normal_zero_sd() {
        let mut rng = Prng::new(7);
        for _ in 0..5 {
            let v = rng.normal(3.0, 0.0);
            assert_eq!(v, 3.0);
        }
    }

    // The widest usize range must not overflow or hit the excluded end.
    #[test]
    fn test_random_range_extreme_usize() {
        let mut rng = Prng::new(5);
        for _ in 0..10 {
            let v = rng.random_range(0..usize::MAX);
            assert!(v < usize::MAX);
        }
    }

    // Chi-square goodness-of-fit over 10 buckets (9 degrees of freedom).
    #[test]
    fn test_prng_chi_square_uniform() {
        let mut rng = Prng::new(12345);
        let mut counts = [0usize; 10];
        let samples = 10000;
        for _ in 0..samples {
            let v = rng.random_range(0..10usize);
            counts[v] += 1;
        }
        let expected = samples as f64 / 10.0;
        let chi2: f64 = counts
            .iter()
            .map(|&c| {
                let diff = c as f64 - expected;
                diff * diff / expected
            })
            .sum();
        // "chi-square statistic too high: {chi2}"
        assert!(chi2 < 20.0);
    }

    // Monobit test: about half of all generated bits should be ones.
    #[test]
    fn test_prng_monobit() {
        let mut rng = Prng::new(42);
        let mut ones = 0usize;
        let samples = 1000;
        for _ in 0..samples {
            ones += rng.next_u64().count_ones() as usize;
        }
        let total_bits = samples * 64;
        let ratio = ones as f64 / total_bits as f64;
        // "bit ratio far from 0.5: {ratio}"
        assert!((ratio - 0.5).abs() < 0.01);
    }
}
|
||||
106
src/random/random_core.rs
Normal file
106
src/random/random_core.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
//! Core traits for random number generators and sampling ranges.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::random::{rng, Rng};
|
||||
//! let mut r = rng();
|
||||
//! let value: f64 = r.random_range(0.0..1.0);
|
||||
//! assert!(value >= 0.0 && value < 1.0);
|
||||
//! ```
|
||||
use std::f64::consts::PI;
|
||||
use std::ops::Range;
|
||||
|
||||
/// Trait implemented by random number generators.
pub trait Rng {
    /// Generate the next random `u64` value.
    fn next_u64(&mut self) -> u64;

    /// Generate a value uniformly in the given range.
    ///
    /// # Panics
    ///
    /// The `usize` implementation panics when the range is empty — there is
    /// no value to sample.
    fn random_range<T>(&mut self, range: Range<T>) -> T
    where
        T: RangeSample,
    {
        T::from_u64(self.next_u64(), &range)
    }

    /// Generate a boolean with probability 0.5 of being `true`.
    fn gen_bool(&mut self) -> bool {
        self.random_range(0..2usize) == 1
    }

    /// Sample from a normal distribution using the Box-Muller transform.
    ///
    /// `u1` is mapped into the half-open interval `(0, 1]` so `u1.ln()` is
    /// always finite. Sampling `u1 == 0.0` directly (which
    /// `random_range(0.0..1.0)` can produce) would give an infinite radius
    /// — and a `NaN` result when `sd == 0`.
    fn normal(&mut self, mean: f64, sd: f64) -> f64 {
        // `random_range(0.0..1.0)` yields [0, 1); `1.0 - x` yields (0, 1].
        let u1 = 1.0 - self.random_range(0.0..1.0);
        let u2 = self.random_range(0.0..1.0);
        mean + sd * (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
    }
}

/// Conversion from a raw `u64` into a type within a range.
pub trait RangeSample: Sized {
    /// Map `value` (assumed uniform over all of `u64`) into `range`.
    fn from_u64(value: u64, range: &Range<Self>) -> Self;
}

impl RangeSample for usize {
    /// Modulo reduction into the span.
    ///
    /// NOTE: this carries a slight modulo bias whenever the span does not
    /// divide 2^64 evenly; it is negligible for small spans.
    ///
    /// Panics when the range is empty (`span == 0` → division by zero).
    fn from_u64(value: u64, range: &Range<Self>) -> Self {
        let span = range.end - range.start;
        (value as usize % span) + range.start
    }
}

impl RangeSample for f64 {
    /// Linear map of `value / u64::MAX` into the span. The result can equal
    /// `range.end` exactly when `value == u64::MAX`.
    fn from_u64(value: u64, range: &Range<Self>) -> Self {
        let span = range.end - range.start;
        range.start + (value as f64 / u64::MAX as f64) * span
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // A width-1 usize range maps every input to its single value.
    #[test]
    fn test_range_sample_usize_boundary() {
        assert_eq!(<usize as RangeSample>::from_u64(0, &(0..1)), 0);
        assert_eq!(<usize as RangeSample>::from_u64(u64::MAX, &(0..1)), 0);
    }

    // f64 mapping covers [0, 1]; u64::MAX lands at (or just below) the end.
    #[test]
    fn test_range_sample_f64_boundary() {
        let v0 = <f64 as RangeSample>::from_u64(0, &(0.0..1.0));
        let vmax = <f64 as RangeSample>::from_u64(u64::MAX, &(0.0..1.0));
        assert!(v0 >= 0.0 && v0 < 1.0);
        assert!(vmax > 0.999999999999 && vmax <= 1.0);
    }

    // Offset usize range: results stay within [start, end).
    #[test]
    fn test_range_sample_usize_varied() {
        for i in 0..5 {
            let v = <usize as RangeSample>::from_u64(i, &(10..15));
            assert!(v >= 10 && v < 15);
        }
    }

    // f64 span away from the unit interval.
    #[test]
    fn test_range_sample_f64_span() {
        for val in [0, u64::MAX / 2, u64::MAX] {
            let f = <f64 as RangeSample>::from_u64(val, &(2.0..4.0));
            assert!(f >= 2.0 && f <= 4.0);
        }
    }

    // Width-1 range at an offset: always the single contained value.
    #[test]
    fn test_range_sample_usize_single_value() {
        for val in [0, 1, u64::MAX] {
            let n = <usize as RangeSample>::from_u64(val, &(5..6));
            assert_eq!(n, 5);
        }
    }

    // Ranges spanning zero map correctly into negative territory.
    #[test]
    fn test_range_sample_f64_negative_range() {
        for val in [0, u64::MAX / 3, u64::MAX] {
            let f = <f64 as RangeSample>::from_u64(val, &(-2.0..2.0));
            assert!(f >= -2.0 && f <= 2.0);
        }
    }
}
|
||||
113
src/random/seq.rs
Normal file
113
src/random/seq.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
//! Extensions for shuffling slices with a random number generator.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::random::{rng, SliceRandom};
|
||||
//! let mut data = [1, 2, 3];
|
||||
//! data.shuffle(&mut rng());
|
||||
//! assert_eq!(data.len(), 3);
|
||||
//! ```
|
||||
use crate::random::Rng;
|
||||
|
||||
/// Trait for randomizing slices.
pub trait SliceRandom {
    /// Shuffle the slice in place using the provided RNG.
    ///
    /// Implementations are expected to yield a uniform permutation when the
    /// RNG itself is uniform.
    fn shuffle<R: Rng>(&mut self, rng: &mut R);
}
|
||||
|
||||
impl<T> SliceRandom for [T] {
|
||||
fn shuffle<R: Rng>(&mut self, rng: &mut R) {
|
||||
for i in (1..self.len()).rev() {
|
||||
let j = rng.random_range(0..(i + 1));
|
||||
self.swap(i, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random::{CryptoRng, Prng};

    // Shuffling permutes: same multiset of elements, same length.
    #[test]
    fn test_shuffle_slice() {
        let mut rng = Prng::new(3);
        let mut arr = [1, 2, 3, 4, 5];
        let orig = arr.clone();
        arr.shuffle(&mut rng);
        assert_eq!(arr.len(), orig.len());
        let mut sorted = arr.to_vec();
        sorted.sort();
        assert_eq!(sorted, orig.to_vec());
    }

    // Equal PRNG seeds produce the identical permutation.
    #[test]
    fn test_slice_shuffle_deterministic_with_prng() {
        let mut rng1 = Prng::new(11);
        let mut rng2 = Prng::new(11);
        let mut a = [1u8, 2, 3, 4, 5, 6, 7, 8, 9];
        let mut b = a.clone();
        a.shuffle(&mut rng1);
        b.shuffle(&mut rng2);
        assert_eq!(a, b);
    }

    // Independent crypto-backed shuffles should produce different orders.
    // NOTE(review): assert_ne can flake with probability ~1/9! (~3e-6) when
    // two random permutations coincide — accepted as negligible.
    #[test]
    fn test_slice_shuffle_crypto_random_changes() {
        let mut rng1 = CryptoRng::new();
        let mut rng2 = CryptoRng::new();
        let orig = [1u8, 2, 3, 4, 5, 6, 7, 8, 9];
        let mut a = orig.clone();
        let mut b = orig.clone();
        a.shuffle(&mut rng1);
        b.shuffle(&mut rng2);
        assert!(a != orig || b != orig, "Shuffles did not change order");
        assert_ne!(a, b, "Two Crypto RNG shuffles produced same order");
    }

    // A single-element slice has only one permutation.
    #[test]
    fn test_shuffle_single_element_no_change() {
        let mut rng = Prng::new(1);
        let mut arr = [42];
        arr.shuffle(&mut rng);
        assert_eq!(arr, [42]);
    }

    // Consecutive shuffles from one advancing stream differ (deterministic
    // for this fixed seed).
    #[test]
    fn test_multiple_shuffles_different_results() {
        let mut rng = Prng::new(5);
        let mut arr1 = [1, 2, 3, 4];
        let mut arr2 = [1, 2, 3, 4];
        arr1.shuffle(&mut rng);
        arr2.shuffle(&mut rng);
        assert_ne!(arr1, arr2);
    }

    // Empty slices are a no-op and must not panic.
    #[test]
    fn test_shuffle_empty_slice() {
        let mut rng = Prng::new(1);
        let mut arr: [i32; 0] = [];
        arr.shuffle(&mut rng);
        assert!(arr.is_empty());
    }

    // All 3! = 6 permutations should occur with near-equal frequency
    // (chi-square over 6 buckets, ~1000 expected each).
    #[test]
    fn test_shuffle_three_uniform() {
        use std::collections::HashMap;
        let mut rng = Prng::new(123);
        let mut counts: HashMap<[u8; 3], usize> = HashMap::new();
        for _ in 0..6000 {
            let mut arr = [1u8, 2, 3];
            arr.shuffle(&mut rng);
            *counts.entry(arr).or_insert(0) += 1;
        }
        let expected = 1000.0;
        let chi2: f64 = counts
            .values()
            .map(|&c| {
                let diff = c as f64 - expected;
                diff * diff / expected
            })
            .sum();
        assert!(chi2 < 30.0, "shuffle chi-square too high: {chi2}");
    }
}
|
||||
2410
src/utils/bdates.rs
2410
src/utils/bdates.rs
File diff suppressed because it is too large
Load Diff
1186
src/utils/dateutils/bdates.rs
Normal file
1186
src/utils/dateutils/bdates.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,10 @@
|
||||
//! Generation and manipulation of calendar date sequences.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::utils::dateutils::dates::{DateFreq, DatesList};
|
||||
//! let list = DatesList::new("2024-01-01".into(), "2024-01-03".into(), DateFreq::Daily);
|
||||
//! assert_eq!(list.count().unwrap(), 3);
|
||||
//! ```
|
||||
use chrono::{Datelike, Duration, NaiveDate, Weekday};
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
@@ -5,8 +12,6 @@ use std::hash::Hash;
|
||||
use std::result::Result;
|
||||
use std::str::FromStr;
|
||||
|
||||
// --- Core Enums ---
|
||||
|
||||
/// Represents the frequency at which calendar dates should be generated.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum DateFreq {
|
||||
@@ -124,8 +129,6 @@ impl FromStr for DateFreq {
|
||||
}
|
||||
}
|
||||
|
||||
// --- DatesList Struct ---
|
||||
|
||||
/// Represents a list of calendar dates generated between a start and end date
|
||||
/// at a specified frequency. Provides methods to retrieve the full list,
|
||||
/// count, or dates grouped by period.
|
||||
@@ -164,7 +167,7 @@ enum GroupKey {
|
||||
/// ```rust
|
||||
/// use chrono::NaiveDate;
|
||||
/// use std::error::Error;
|
||||
/// # use rustframe::utils::{DatesList, DateFreq}; // Assuming the crate/module is named 'dates'
|
||||
/// use rustframe::utils::{DatesList, DateFreq};
|
||||
///
|
||||
/// # fn main() -> Result<(), Box<dyn Error>> {
|
||||
/// let start_date = "2023-11-01".to_string(); // Wednesday
|
||||
@@ -340,32 +343,7 @@ impl DatesList {
|
||||
/// Returns an error if the start or end date strings cannot be parsed.
|
||||
pub fn groups(&self) -> Result<Vec<Vec<NaiveDate>>, Box<dyn Error>> {
|
||||
let dates = self.list()?;
|
||||
let mut groups: HashMap<GroupKey, Vec<NaiveDate>> = HashMap::new();
|
||||
|
||||
for date in dates {
|
||||
let key = match self.freq {
|
||||
DateFreq::Daily => GroupKey::Daily(date),
|
||||
DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => {
|
||||
let iso_week = date.iso_week();
|
||||
GroupKey::Weekly(iso_week.year(), iso_week.week())
|
||||
}
|
||||
DateFreq::MonthStart | DateFreq::MonthEnd => {
|
||||
GroupKey::Monthly(date.year(), date.month())
|
||||
}
|
||||
DateFreq::QuarterStart | DateFreq::QuarterEnd => {
|
||||
GroupKey::Quarterly(date.year(), month_to_quarter(date.month()))
|
||||
}
|
||||
DateFreq::YearStart | DateFreq::YearEnd => GroupKey::Yearly(date.year()),
|
||||
};
|
||||
groups.entry(key).or_insert_with(Vec::new).push(date);
|
||||
}
|
||||
|
||||
let mut sorted_groups: Vec<(GroupKey, Vec<NaiveDate>)> = groups.into_iter().collect();
|
||||
sorted_groups.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
|
||||
|
||||
// Dates within groups are already sorted because they came from the sorted `self.list()`.
|
||||
let result_groups = sorted_groups.into_iter().map(|(_, dates)| dates).collect();
|
||||
Ok(result_groups)
|
||||
group_dates_helper(dates, self.freq)
|
||||
}
|
||||
|
||||
/// Returns the start date parsed as a `NaiveDate`.
|
||||
@@ -407,8 +385,6 @@ impl DatesList {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Dates Generator (Iterator) ---
|
||||
|
||||
/// An iterator that generates a sequence of calendar dates based on a start date,
|
||||
/// frequency, and a specified number of periods.
|
||||
///
|
||||
@@ -492,10 +468,10 @@ impl DatesList {
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DatesGenerator {
|
||||
freq: DateFreq,
|
||||
periods_remaining: usize,
|
||||
pub freq: DateFreq,
|
||||
pub periods_remaining: usize,
|
||||
// Stores the *next* date to be yielded by the iterator.
|
||||
next_date_candidate: Option<NaiveDate>,
|
||||
pub next_date_candidate: Option<NaiveDate>,
|
||||
}
|
||||
|
||||
impl DatesGenerator {
|
||||
@@ -561,11 +537,43 @@ impl Iterator for DatesGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Internal helper functions ---
|
||||
// Internal helper functions
|
||||
|
||||
pub fn group_dates_helper(
|
||||
dates: Vec<NaiveDate>,
|
||||
freq: DateFreq,
|
||||
) -> Result<Vec<Vec<NaiveDate>>, Box<dyn Error + 'static>> {
|
||||
let mut groups: HashMap<GroupKey, Vec<NaiveDate>> = HashMap::new();
|
||||
|
||||
for date in dates {
|
||||
let key = match freq {
|
||||
DateFreq::Daily => GroupKey::Daily(date),
|
||||
DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => {
|
||||
let iso_week = date.iso_week();
|
||||
GroupKey::Weekly(iso_week.year(), iso_week.week())
|
||||
}
|
||||
DateFreq::MonthStart | DateFreq::MonthEnd => {
|
||||
GroupKey::Monthly(date.year(), date.month())
|
||||
}
|
||||
DateFreq::QuarterStart | DateFreq::QuarterEnd => {
|
||||
GroupKey::Quarterly(date.year(), month_to_quarter(date.month()))
|
||||
}
|
||||
DateFreq::YearStart | DateFreq::YearEnd => GroupKey::Yearly(date.year()),
|
||||
};
|
||||
groups.entry(key).or_insert_with(Vec::new).push(date);
|
||||
}
|
||||
|
||||
let mut sorted_groups: Vec<(GroupKey, Vec<NaiveDate>)> = groups.into_iter().collect();
|
||||
sorted_groups.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
|
||||
|
||||
// Dates within groups are already sorted because they came from the sorted `self.list()`.
|
||||
let result_groups = sorted_groups.into_iter().map(|(_, dates)| dates).collect();
|
||||
Ok(result_groups)
|
||||
}
|
||||
|
||||
/// Generates the flat list of dates for the given range and frequency.
|
||||
/// Assumes the `collect_*` functions return sorted dates.
|
||||
fn get_dates_list_with_freq(
|
||||
pub fn get_dates_list_with_freq(
|
||||
start_date_str: &str,
|
||||
end_date_str: &str,
|
||||
freq: DateFreq,
|
||||
@@ -601,7 +609,7 @@ fn get_dates_list_with_freq(
|
||||
Ok(dates)
|
||||
}
|
||||
|
||||
/* ---------------------- Low-Level Date Collection Functions (Internal) ---------------------- */
|
||||
// Low-Level Date Collection Functions (Internal)
|
||||
// These functions generate dates within a *range* [start_date, end_date]
|
||||
|
||||
/// Returns all calendar days day-by-day within the range.
|
||||
@@ -648,8 +656,13 @@ fn collect_monthly(
|
||||
let mut year = start_date.year();
|
||||
let mut month = start_date.month();
|
||||
|
||||
let next_month =
|
||||
|(yr, mo): (i32, u32)| -> (i32, u32) { if mo == 12 { (yr + 1, 1) } else { (yr, mo + 1) } };
|
||||
let next_month = |(yr, mo): (i32, u32)| -> (i32, u32) {
|
||||
if mo == 12 {
|
||||
(yr + 1, 1)
|
||||
} else {
|
||||
(yr, mo + 1)
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
let candidate = if want_first_day {
|
||||
@@ -728,6 +741,21 @@ fn collect_quarterly(
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a list of dates between the given start and end dates, inclusive,
|
||||
/// at the specified frequency.
|
||||
/// This function is a convenience wrapper around `get_dates_list_with_freq`.
|
||||
pub fn get_dates_list_with_freq_from_naive_date(
|
||||
start_date: NaiveDate,
|
||||
end_date: NaiveDate,
|
||||
freq: DateFreq,
|
||||
) -> Result<Vec<NaiveDate>, Box<dyn Error>> {
|
||||
get_dates_list_with_freq(
|
||||
&start_date.format("%Y-%m-%d").to_string(),
|
||||
&end_date.format("%Y-%m-%d").to_string(),
|
||||
freq,
|
||||
)
|
||||
}
|
||||
|
||||
/// Return either the first or last calendar day in each year of the range.
|
||||
fn collect_yearly(
|
||||
start_date: NaiveDate,
|
||||
@@ -757,8 +785,6 @@ fn collect_yearly(
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/* ---------------------- Core Date Utility Functions (Internal) ---------------------- */
|
||||
|
||||
/// Given a date and a `target_weekday`, returns the date that is the first
|
||||
/// `target_weekday` on or after the given date.
|
||||
fn move_to_day_of_week_on_or_after(
|
||||
@@ -814,7 +840,7 @@ fn last_day_of_month(year: i32, month: u32) -> Result<NaiveDate, Box<dyn Error>>
|
||||
|
||||
/// Converts a month number (1-12) to a quarter number (1-4).
|
||||
/// Panics if month is invalid (should not happen with valid NaiveDate).
|
||||
fn month_to_quarter(m: u32) -> u32 {
|
||||
pub fn month_to_quarter(m: u32) -> u32 {
|
||||
match m {
|
||||
1..=3 => 1,
|
||||
4..=6 => 2,
|
||||
@@ -873,9 +899,28 @@ fn last_day_of_year(year: i32) -> Result<NaiveDate, Box<dyn Error>> {
|
||||
|
||||
// --- Generator Helper Functions ---
|
||||
|
||||
fn get_first_date_helper(freq: DateFreq) -> fn(i32, u32) -> Result<NaiveDate, Box<dyn Error>> {
|
||||
if matches!(
|
||||
freq,
|
||||
DateFreq::Daily | DateFreq::WeeklyMonday | DateFreq::WeeklyFriday
|
||||
) {
|
||||
panic!("Daily, WeeklyMonday, and WeeklyFriday frequencies are not supported here");
|
||||
}
|
||||
|
||||
match freq {
|
||||
DateFreq::MonthStart => first_day_of_month,
|
||||
DateFreq::MonthEnd => last_day_of_month,
|
||||
DateFreq::QuarterStart => first_day_of_quarter,
|
||||
DateFreq::QuarterEnd => last_day_of_quarter,
|
||||
DateFreq::YearStart => |year, _| first_day_of_year(year),
|
||||
DateFreq::YearEnd => |year, _| last_day_of_year(year),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds the *first* valid date according to the frequency,
|
||||
/// starting the search *on or after* the given `start_date`.
|
||||
fn find_first_date_on_or_after(
|
||||
pub fn find_first_date_on_or_after(
|
||||
start_date: NaiveDate,
|
||||
freq: DateFreq,
|
||||
) -> Result<NaiveDate, Box<dyn Error>> {
|
||||
@@ -883,69 +928,42 @@ fn find_first_date_on_or_after(
|
||||
DateFreq::Daily => Ok(start_date), // The first daily date is the start date itself
|
||||
DateFreq::WeeklyMonday => move_to_day_of_week_on_or_after(start_date, Weekday::Mon),
|
||||
DateFreq::WeeklyFriday => move_to_day_of_week_on_or_after(start_date, Weekday::Fri),
|
||||
DateFreq::MonthStart => {
|
||||
let mut candidate = first_day_of_month(start_date.year(), start_date.month())?;
|
||||
|
||||
DateFreq::MonthStart | DateFreq::MonthEnd => {
|
||||
// let mut candidate = first_day_of_month(start_date.year(), start_date.month())?;
|
||||
let get_cand_func = get_first_date_helper(freq);
|
||||
let mut candidate = get_cand_func(start_date.year(), start_date.month())?;
|
||||
if candidate < start_date {
|
||||
let (next_y, next_m) = if start_date.month() == 12 {
|
||||
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
|
||||
} else {
|
||||
(start_date.year(), start_date.month() + 1)
|
||||
};
|
||||
candidate = first_day_of_month(next_y, next_m)?;
|
||||
candidate = get_cand_func(next_y, next_m)?;
|
||||
}
|
||||
Ok(candidate)
|
||||
}
|
||||
DateFreq::MonthEnd => {
|
||||
let mut candidate = last_day_of_month(start_date.year(), start_date.month())?;
|
||||
if candidate < start_date {
|
||||
let (next_y, next_m) = if start_date.month() == 12 {
|
||||
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
|
||||
} else {
|
||||
(start_date.year(), start_date.month() + 1)
|
||||
};
|
||||
candidate = last_day_of_month(next_y, next_m)?;
|
||||
}
|
||||
Ok(candidate)
|
||||
}
|
||||
DateFreq::QuarterStart => {
|
||||
DateFreq::QuarterStart | DateFreq::QuarterEnd => {
|
||||
let current_q = month_to_quarter(start_date.month());
|
||||
let mut candidate = first_day_of_quarter(start_date.year(), current_q)?;
|
||||
let get_cand_func = get_first_date_helper(freq);
|
||||
let mut candidate = get_cand_func(start_date.year(), current_q)?;
|
||||
if candidate < start_date {
|
||||
let (next_y, next_q) = if current_q == 4 {
|
||||
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
|
||||
} else {
|
||||
(start_date.year(), current_q + 1)
|
||||
};
|
||||
candidate = first_day_of_quarter(next_y, next_q)?;
|
||||
candidate = get_cand_func(next_y, next_q)?;
|
||||
}
|
||||
Ok(candidate)
|
||||
}
|
||||
DateFreq::QuarterEnd => {
|
||||
let current_q = month_to_quarter(start_date.month());
|
||||
let mut candidate = last_day_of_quarter(start_date.year(), current_q)?;
|
||||
if candidate < start_date {
|
||||
let (next_y, next_q) = if current_q == 4 {
|
||||
(start_date.year().checked_add(1).ok_or("Year overflow")?, 1)
|
||||
} else {
|
||||
(start_date.year(), current_q + 1)
|
||||
};
|
||||
candidate = last_day_of_quarter(next_y, next_q)?;
|
||||
}
|
||||
Ok(candidate)
|
||||
}
|
||||
DateFreq::YearStart => {
|
||||
let mut candidate = first_day_of_year(start_date.year())?;
|
||||
|
||||
DateFreq::YearStart | DateFreq::YearEnd => {
|
||||
let get_cand_func = get_first_date_helper(freq);
|
||||
let mut candidate = get_cand_func(start_date.year(), 0)?;
|
||||
if candidate < start_date {
|
||||
candidate =
|
||||
first_day_of_year(start_date.year().checked_add(1).ok_or("Year overflow")?)?;
|
||||
}
|
||||
Ok(candidate)
|
||||
}
|
||||
DateFreq::YearEnd => {
|
||||
let mut candidate = last_day_of_year(start_date.year())?;
|
||||
if candidate < start_date {
|
||||
candidate =
|
||||
last_day_of_year(start_date.year().checked_add(1).ok_or("Year overflow")?)?;
|
||||
get_cand_func(start_date.year().checked_add(1).ok_or("Year overflow")?, 0)?;
|
||||
}
|
||||
Ok(candidate)
|
||||
}
|
||||
@@ -954,7 +972,10 @@ fn find_first_date_on_or_after(
|
||||
|
||||
/// Finds the *next* valid date according to the frequency,
|
||||
/// given the `current_date` (which is assumed to be a valid date previously generated).
|
||||
fn find_next_date(current_date: NaiveDate, freq: DateFreq) -> Result<NaiveDate, Box<dyn Error>> {
|
||||
pub fn find_next_date(
|
||||
current_date: NaiveDate,
|
||||
freq: DateFreq,
|
||||
) -> Result<NaiveDate, Box<dyn Error>> {
|
||||
match freq {
|
||||
DateFreq::Daily => current_date
|
||||
.succ_opt()
|
||||
@@ -962,7 +983,8 @@ fn find_next_date(current_date: NaiveDate, freq: DateFreq) -> Result<NaiveDate,
|
||||
DateFreq::WeeklyMonday | DateFreq::WeeklyFriday => current_date
|
||||
.checked_add_signed(Duration::days(7))
|
||||
.ok_or_else(|| "Date overflow adding 7 days".into()),
|
||||
DateFreq::MonthStart => {
|
||||
DateFreq::MonthStart | DateFreq::MonthEnd => {
|
||||
let get_cand_func = get_first_date_helper(freq);
|
||||
let (next_y, next_m) = if current_date.month() == 12 {
|
||||
(
|
||||
current_date.year().checked_add(1).ok_or("Year overflow")?,
|
||||
@@ -971,21 +993,11 @@ fn find_next_date(current_date: NaiveDate, freq: DateFreq) -> Result<NaiveDate,
|
||||
} else {
|
||||
(current_date.year(), current_date.month() + 1)
|
||||
};
|
||||
first_day_of_month(next_y, next_m)
|
||||
get_cand_func(next_y, next_m)
|
||||
}
|
||||
DateFreq::MonthEnd => {
|
||||
let (next_y, next_m) = if current_date.month() == 12 {
|
||||
(
|
||||
current_date.year().checked_add(1).ok_or("Year overflow")?,
|
||||
1,
|
||||
)
|
||||
} else {
|
||||
(current_date.year(), current_date.month() + 1)
|
||||
};
|
||||
last_day_of_month(next_y, next_m)
|
||||
}
|
||||
DateFreq::QuarterStart => {
|
||||
DateFreq::QuarterStart | DateFreq::QuarterEnd => {
|
||||
let current_q = month_to_quarter(current_date.month());
|
||||
let get_cand_func = get_first_date_helper(freq);
|
||||
let (next_y, next_q) = if current_q == 4 {
|
||||
(
|
||||
current_date.year().checked_add(1).ok_or("Year overflow")?,
|
||||
@@ -994,25 +1006,14 @@ fn find_next_date(current_date: NaiveDate, freq: DateFreq) -> Result<NaiveDate,
|
||||
} else {
|
||||
(current_date.year(), current_q + 1)
|
||||
};
|
||||
first_day_of_quarter(next_y, next_q)
|
||||
get_cand_func(next_y, next_q)
|
||||
}
|
||||
DateFreq::QuarterEnd => {
|
||||
let current_q = month_to_quarter(current_date.month());
|
||||
let (next_y, next_q) = if current_q == 4 {
|
||||
(
|
||||
DateFreq::YearStart | DateFreq::YearEnd => {
|
||||
let get_cand_func = get_first_date_helper(freq);
|
||||
get_cand_func(
|
||||
current_date.year().checked_add(1).ok_or("Year overflow")?,
|
||||
1,
|
||||
0,
|
||||
)
|
||||
} else {
|
||||
(current_date.year(), current_q + 1)
|
||||
};
|
||||
last_day_of_quarter(next_y, next_q)
|
||||
}
|
||||
DateFreq::YearStart => {
|
||||
first_day_of_year(current_date.year().checked_add(1).ok_or("Year overflow")?)
|
||||
}
|
||||
DateFreq::YearEnd => {
|
||||
last_day_of_year(current_date.year().checked_add(1).ok_or("Year overflow")?)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1475,8 +1476,6 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// --- Tests for internal helper functions ---
|
||||
|
||||
#[test]
|
||||
fn test_move_to_day_of_week_on_or_after() -> Result<(), Box<dyn Error>> {
|
||||
assert_eq!(
|
||||
@@ -1508,7 +1507,8 @@ mod tests {
|
||||
// And trying to move *past* it should fail
|
||||
let day_before = NaiveDate::MAX - Duration::days(1);
|
||||
let target_day_after = NaiveDate::MAX.weekday().succ(); // Day after MAX's weekday
|
||||
assert!(move_to_day_of_week_on_or_after(day_before, target_day_after).is_err()); // Moving past MAX fails
|
||||
assert!(move_to_day_of_week_on_or_after(day_before, target_day_after).is_err());
|
||||
// Moving past MAX fails
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -1527,12 +1527,15 @@ mod tests {
|
||||
fn test_days_in_month() -> Result<(), Box<dyn Error>> {
|
||||
assert_eq!(days_in_month(2023, 1)?, 31);
|
||||
assert_eq!(days_in_month(2023, 2)?, 28);
|
||||
assert_eq!(days_in_month(2024, 2)?, 29); // Leap
|
||||
// Leap
|
||||
assert_eq!(days_in_month(2024, 2)?, 29);
|
||||
assert_eq!(days_in_month(2023, 4)?, 30);
|
||||
assert_eq!(days_in_month(2023, 12)?, 31);
|
||||
assert!(days_in_month(2023, 0).is_err()); // Invalid month 0
|
||||
assert!(days_in_month(2023, 13).is_err()); // Invalid month 13
|
||||
// Invalid month 0
|
||||
assert!(days_in_month(2023, 0).is_err());
|
||||
// Invalid month 13
|
||||
// Test near max date year overflow - Use MAX.year()
|
||||
assert!(days_in_month(2023, 13).is_err());
|
||||
assert!(days_in_month(NaiveDate::MAX.year(), 12).is_err());
|
||||
Ok(())
|
||||
}
|
||||
@@ -1542,9 +1545,12 @@ mod tests {
|
||||
assert_eq!(last_day_of_month(2023, 11)?, date(2023, 11, 30));
|
||||
assert_eq!(last_day_of_month(2024, 2)?, date(2024, 2, 29)); // Leap
|
||||
assert_eq!(last_day_of_month(2023, 12)?, date(2023, 12, 31));
|
||||
assert!(last_day_of_month(2023, 0).is_err()); // Invalid month 0
|
||||
assert!(last_day_of_month(2023, 13).is_err()); // Invalid month 13
|
||||
// Invalid month 0
|
||||
assert!(last_day_of_month(2023, 0).is_err());
|
||||
// Invalid month 13
|
||||
// Test near max date year overflow - use MAX.year()
|
||||
assert!(last_day_of_month(2023, 13).is_err());
|
||||
|
||||
assert!(last_day_of_month(NaiveDate::MAX.year(), 12).is_err());
|
||||
Ok(())
|
||||
}
|
||||
@@ -1588,7 +1594,8 @@ mod tests {
|
||||
assert_eq!(first_day_of_quarter(2023, 2)?, date(2023, 4, 1));
|
||||
assert_eq!(first_day_of_quarter(2023, 3)?, date(2023, 7, 1));
|
||||
assert_eq!(first_day_of_quarter(2023, 4)?, date(2023, 10, 1));
|
||||
assert!(first_day_of_quarter(2023, 5).is_err()); // Invalid quarter
|
||||
// Invalid quarter
|
||||
assert!(first_day_of_quarter(2023, 5).is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1608,9 +1615,11 @@ mod tests {
|
||||
assert_eq!(last_day_of_quarter(2023, 2)?, date(2023, 6, 30));
|
||||
assert_eq!(last_day_of_quarter(2023, 3)?, date(2023, 9, 30));
|
||||
assert_eq!(last_day_of_quarter(2023, 4)?, date(2023, 12, 31));
|
||||
assert_eq!(last_day_of_quarter(2024, 1)?, date(2024, 3, 31)); // Leap year doesn't affect March end
|
||||
assert!(last_day_of_quarter(2023, 5).is_err()); // Invalid quarter
|
||||
// Leap year doesn't affect March end
|
||||
assert_eq!(last_day_of_quarter(2024, 1)?, date(2024, 3, 31));
|
||||
// Invalid quarter
|
||||
// Test overflow propagation - use MAX.year()
|
||||
assert!(last_day_of_quarter(2023, 5).is_err());
|
||||
assert!(last_day_of_quarter(NaiveDate::MAX.year(), 4).is_err());
|
||||
Ok(())
|
||||
}
|
||||
@@ -1627,16 +1636,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_last_day_of_year() -> Result<(), Box<dyn Error>> {
|
||||
assert_eq!(last_day_of_year(2023)?, date(2023, 12, 31));
|
||||
assert_eq!(last_day_of_year(2024)?, date(2024, 12, 31)); // Leap year doesn't affect Dec 31st existence
|
||||
// Leap year doesn't affect Dec 31st existence
|
||||
// Test MAX year - should be okay since MAX is Dec 31
|
||||
assert_eq!(last_day_of_year(2024)?, date(2024, 12, 31));
|
||||
assert_eq!(last_day_of_year(NaiveDate::MAX.year())?, NaiveDate::MAX);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Overflow tests for collect_* removed as they were misleading
|
||||
|
||||
// --- Tests for Generator Helper Functions ---
|
||||
|
||||
#[test]
|
||||
fn test_find_first_date_on_or_after() -> Result<(), Box<dyn Error>> {
|
||||
// Daily
|
||||
@@ -1644,10 +1650,11 @@ mod tests {
|
||||
find_first_date_on_or_after(date(2023, 11, 8), DateFreq::Daily)?,
|
||||
date(2023, 11, 8)
|
||||
);
|
||||
// Sat -> Sat
|
||||
assert_eq!(
|
||||
find_first_date_on_or_after(date(2023, 11, 11), DateFreq::Daily)?,
|
||||
date(2023, 11, 11)
|
||||
); // Sat -> Sat
|
||||
);
|
||||
|
||||
// Weekly Mon
|
||||
assert_eq!(
|
||||
@@ -1658,10 +1665,11 @@ mod tests {
|
||||
find_first_date_on_or_after(date(2023, 11, 13), DateFreq::WeeklyMonday)?,
|
||||
date(2023, 11, 13)
|
||||
);
|
||||
// Sun -> Mon
|
||||
assert_eq!(
|
||||
find_first_date_on_or_after(date(2023, 11, 12), DateFreq::WeeklyMonday)?,
|
||||
date(2023, 11, 13)
|
||||
); // Sun -> Mon
|
||||
);
|
||||
|
||||
// Weekly Fri
|
||||
assert_eq!(
|
||||
@@ -1690,10 +1698,11 @@ mod tests {
|
||||
find_first_date_on_or_after(date(2023, 12, 15), DateFreq::MonthStart)?,
|
||||
date(2024, 1, 1)
|
||||
);
|
||||
// Oct 1 -> Oct 1
|
||||
assert_eq!(
|
||||
find_first_date_on_or_after(date(2023, 10, 1), DateFreq::MonthStart)?,
|
||||
date(2023, 10, 1)
|
||||
); // Oct 1 -> Oct 1
|
||||
);
|
||||
|
||||
// Month End
|
||||
assert_eq!(
|
||||
@@ -1704,18 +1713,21 @@ mod tests {
|
||||
find_first_date_on_or_after(date(2023, 11, 15), DateFreq::MonthEnd)?,
|
||||
date(2023, 11, 30)
|
||||
);
|
||||
// Dec 31 -> Dec 31
|
||||
assert_eq!(
|
||||
find_first_date_on_or_after(date(2023, 12, 31), DateFreq::MonthEnd)?,
|
||||
date(2023, 12, 31)
|
||||
); // Dec 31 -> Dec 31
|
||||
);
|
||||
// Mid Feb (Leap) -> Feb 29
|
||||
assert_eq!(
|
||||
find_first_date_on_or_after(date(2024, 2, 15), DateFreq::MonthEnd)?,
|
||||
date(2024, 2, 29)
|
||||
); // Mid Feb (Leap) -> Feb 29
|
||||
);
|
||||
// Feb 29 -> Feb 29
|
||||
assert_eq!(
|
||||
find_first_date_on_or_after(date(2024, 2, 29), DateFreq::MonthEnd)?,
|
||||
date(2024, 2, 29)
|
||||
); // Feb 29 -> Feb 29
|
||||
);
|
||||
|
||||
// Quarter Start
|
||||
assert_eq!(
|
||||
@@ -1920,13 +1932,11 @@ mod tests {
|
||||
assert!(find_next_date(NaiveDate::MAX, DateFreq::MonthEnd).is_err());
|
||||
|
||||
// Test finding next quarter start after Q4 MAX_YEAR -> Q1 (MAX_YEAR+1) (fail)
|
||||
assert!(
|
||||
find_next_date(
|
||||
assert!(find_next_date(
|
||||
first_day_of_quarter(NaiveDate::MAX.year(), 4)?,
|
||||
DateFreq::QuarterStart
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
.is_err());
|
||||
|
||||
// Test finding next quarter end after Q3 MAX_YEAR -> Q4 MAX_YEAR (fails because last_day_of_quarter(MAX, 4) fails)
|
||||
let q3_end_max_year = last_day_of_quarter(NaiveDate::MAX.year(), 3)?;
|
||||
@@ -1937,22 +1947,18 @@ mod tests {
|
||||
assert!(find_next_date(NaiveDate::MAX, DateFreq::QuarterEnd).is_err());
|
||||
|
||||
// Test finding next year start after Jan 1 MAX_YEAR -> Jan 1 (MAX_YEAR+1) (fail)
|
||||
assert!(
|
||||
find_next_date(
|
||||
assert!(find_next_date(
|
||||
first_day_of_year(NaiveDate::MAX.year())?,
|
||||
DateFreq::YearStart
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
.is_err());
|
||||
|
||||
// Test finding next year end after Dec 31 (MAX_YEAR-1) -> Dec 31 MAX_YEAR (ok)
|
||||
assert!(
|
||||
find_next_date(
|
||||
assert!(find_next_date(
|
||||
last_day_of_year(NaiveDate::MAX.year() - 1)?,
|
||||
DateFreq::YearEnd
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
.is_ok());
|
||||
|
||||
// Test finding next year end after Dec 31 MAX_YEAR -> Dec 31 (MAX_YEAR+1) (fail)
|
||||
assert!(
|
||||
@@ -2152,32 +2158,15 @@ mod tests {
|
||||
// find_first returns start_date (YE MAX-1)
|
||||
assert_eq!(generator.next(), Some(start_date));
|
||||
// find_next finds YE(MAX)
|
||||
assert_eq!(generator.next(), Some(last_day_of_year(start_year)?)); // Should be MAX
|
||||
assert_eq!(generator.next(), Some(last_day_of_year(start_year)?));
|
||||
// Should be MAX
|
||||
// find_next tries YE(MAX+1) - this call to find_next_date fails internally
|
||||
assert_eq!(generator.next(), None); // Returns None because internal find_next_date failed
|
||||
|
||||
// Check internal state after the call that returned None
|
||||
// When Some(YE MAX) was returned, periods_remaining became 1.
|
||||
// The next call enters the match, calls find_next_date (fails -> .ok() is None),
|
||||
// sets next_date_candidate=None, decrements periods_remaining to 0, returns Some(YE MAX).
|
||||
// --> NO, the code was: set candidate=find().ok(), THEN decrement.
|
||||
// Let's revisit Iterator::next logic:
|
||||
// 1. periods_remaining = 1, next_date_candidate = Some(YE MAX)
|
||||
// 2. Enter match arm
|
||||
// 3. find_next_date(YE MAX, YE) -> Err
|
||||
// 4. self.next_date_candidate = Err.ok() -> None
|
||||
// 5. self.periods_remaining -= 1 -> becomes 0
|
||||
// 6. return Some(YE MAX) <-- This was the bug in my reasoning. It returns the *current* date first.
|
||||
// State after returning Some(YE MAX): periods_remaining = 0, next_date_candidate = None
|
||||
// Next call to generator.next():
|
||||
// 1. periods_remaining = 0
|
||||
// 2. Enter the `_` arm of the match
|
||||
// 3. self.periods_remaining = 0 (no change)
|
||||
// 4. self.next_date_candidate = None (no change)
|
||||
// 5. return None
|
||||
assert_eq!(generator.next(), None);
|
||||
// Returns None because internal find_next_date failed
|
||||
|
||||
// State after the *first* None is returned:
|
||||
assert_eq!(generator.periods_remaining, 0); // Corrected assertion
|
||||
// Corrected assertion
|
||||
assert_eq!(generator.periods_remaining, 0);
|
||||
assert!(generator.next_date_candidate.is_none());
|
||||
|
||||
// Calling next() again should also return None
|
||||
15
src/utils/dateutils/mod.rs
Normal file
15
src/utils/dateutils/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Generators for sequences of calendar and business dates.
|
||||
//!
|
||||
//! See [`dates`] for all-day calendars and [`bdates`] for business-day aware
|
||||
//! variants.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::utils::dateutils::{DatesList, DateFreq};
|
||||
//! let list = DatesList::new("2024-01-01".into(), "2024-01-02".into(), DateFreq::Daily);
|
||||
//! assert_eq!(list.count().unwrap(), 2);
|
||||
//! ```
|
||||
pub mod bdates;
|
||||
pub mod dates;
|
||||
|
||||
pub use bdates::{BDateFreq, BDatesGenerator, BDatesList};
|
||||
pub use dates::{DateFreq, DatesGenerator, DatesList};
|
||||
@@ -1,6 +1,15 @@
|
||||
pub mod bdates;
|
||||
pub use bdates::{BDateFreq, BDatesList, BDatesGenerator};
|
||||
|
||||
pub mod dates;
|
||||
pub use dates::{DateFreq, DatesList, DatesGenerator};
|
||||
//! Assorted helper utilities.
|
||||
//!
|
||||
//! Currently this module exposes date generation utilities in [`dateutils`](crate::utils::dateutils),
|
||||
//! including calendar and business date sequences.
|
||||
//!
|
||||
//! ```
|
||||
//! use rustframe::utils::DatesList;
|
||||
//! use rustframe::utils::DateFreq;
|
||||
//! let dates = DatesList::new("2024-01-01".into(), "2024-01-03".into(), DateFreq::Daily);
|
||||
//! assert_eq!(dates.count().unwrap(), 3);
|
||||
//! ```
|
||||
pub mod dateutils;
|
||||
|
||||
pub use dateutils::{BDateFreq, BDatesGenerator, BDatesList};
|
||||
pub use dateutils::{DateFreq, DatesGenerator, DatesList};
|
||||
|
||||
Reference in New Issue
Block a user