feat: checkpoint user state tracking service with PostgreSQL

- RESTful API: POST /heartbeat, POST /checkpoints, GET /status, GET /summaries
- State-change-only checkpoint model with extensible StateType enum
- PostgreSQL backend with sqlx, auto-migration on startup
- pg_cron scheduled aggregation (state_summaries) and offline detection
- Heartbeat-based liveness with 60s timeout auto-offline
- LEAD() window function for state duration calculation
- JSONB content field for extensible checkpoint metadata

BREAKING CHANGE: Complete rewrite from Hello World to full API service.
This commit is contained in:
2026-05-31 22:36:20 +08:00
parent 766b8a84c9
commit 13f7c1326a
20 changed files with 3315 additions and 2 deletions
+2
View File
@@ -1 +1,3 @@
/target
.env
Generated
+2066
View File
File diff suppressed because it is too large Load Diff
+9
View File
@@ -4,3 +4,12 @@ version = "0.1.0"
edition = "2024"
[dependencies]
axum = "0.8"
tokio = { version = "1", features = ["full"] }
tower-http = { version = "0.6", features = ["cors"] }
chrono = "0.4.44"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
async-trait = "0.1"
dotenvy = "0.15"
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "chrono"] }
+142
View File
@@ -0,0 +1,142 @@
# ck-rs — Checkpoint 用户状态追踪服务
基于 [Axum](https://github.com/tokio-rs/axum) 的 RESTful API 服务,用于记录用户在不同状态下的持续时间。客户端定时上报检查点(checkpoint),服务端自动计算各状态累计时长。类比手机屏幕使用时间统计。
## 快速开始
```bash
# 启动服务(默认监听 127.0.0.1:3000,可通过 .env 中 LISTEN_ADDR 修改)
cargo run
```
## API 端点
| 方法 | 路径 | 说明 |
|------|------|------|
| `GET` | `/health` | 服务健康检查 |
| `POST` | `/users/{user_id}/checkpoints` | 上报当前状态(心跳) |
| `GET` | `/users/{user_id}/checkpoints` | 查询检查点历史(`?from=&to=&limit=` |
| `GET` | `/users/{user_id}/checkpoints/{id}` | 查询单个检查点 |
| `GET` | `/users/{user_id}/status` | 当前状态 + 各状态累计时长 |
### 状态类型
| 内置状态 | 说明 |
|----------|------|
| `Online` | 在线 |
| `Offline` | 离线 |
| `Idle` | 空闲 |
| `Working` | 工作中 |
| `Sleeping` | 睡眠 |
| `"任意字符串"` | 自定义状态,如 `Gaming``Meeting``Driving` 等 |
### 请求示例
```bash
BASE=http://localhost:3000
# 健康检查
curl $BASE/health
# 上报状态(自动记录服务端当前时间)
curl -X POST $BASE/users/alice/checkpoints \
-H "Content-Type: application/json" \
-d '{"state":"Working"}'
# 上报状态 + 附带元数据(设备、坐标等)
curl -X POST $BASE/users/alice/checkpoints \
-H "Content-Type: application/json" \
-d '{"state":"Idle","content":{"device":"MacBook","battery":85}}'
# 上报自定义状态
curl -X POST $BASE/users/alice/checkpoints \
-H "Content-Type: application/json" \
-d '{"state":"Gaming"}'
# 等待几秒后切换状态(产生时长数据)
sleep 3
curl -X POST $BASE/users/alice/checkpoints \
-H "Content-Type: application/json" \
-d '{"state":"Offline"}'
# 查询检查点历史
curl $BASE/users/alice/checkpoints
# 按时间范围查询
curl "$BASE/users/alice/checkpoints?from=1717161600&to=1717248000&limit=10"
# 查询单个检查点
curl $BASE/users/alice/checkpoints/1
# 查询状态汇总(当前状态 + 各状态累计时长)
curl $BASE/users/alice/status
# → {"user_id":"alice","current_state":"Offline","since":1717248000,
# "durations":[{"state":"Working","duration_secs":3}]}
```
## 环境变量
| 变量 | 默认值 | 说明 |
|------|--------|------|
| `LISTEN_ADDR` | `127.0.0.1:3000` | 监听地址 |
| `DATABASE_URL` | (无) | 数据库连接字符串(接入真实 DB 时设置) |
## 项目结构
```
src/
├── main.rs # 入口:加载配置 → 初始化 DB → 构建路由 → 启动
├── config.rs # 配置层(环境变量 + 默认值)
├── state.rs # AppState(全局共享状态,持有 DB)
├── error.rs # 统一错误类型 AppError(实现 IntoResponse
├── router.rs # 路由组装
├── models/
│ ├── mod.rs
│ └── checkpoint.rs # StateType / Checkpoint / UserStatusResponse 等
├── handlers/
│ ├── mod.rs
│ ├── health.rs # GET /health
│ └── checkpoints.rs # POST/GET /users/{id}/checkpoints
└── db/
├── mod.rs # Db trait(数据库抽象接口)
└── memory.rs # MemoryDb(内存模拟,开发期使用)
```
## 核心设计
- **状态可自由扩充**`StateType` 内置 5 种状态 + `Custom(String)` 变体,传入任意字符串自动作为新状态
- **content 可扩展**:每个检查点可附带 `Option<serde_json::Value>` 元数据
- **时长自动计算**:相邻检查点的时间差归属于前一个状态,`/status` 返回各状态累计秒数
- **timestamp 可选**:请求可带时间戳,不传则服务端取当前时间
## 接入真实数据库
当前使用内存模拟存储(`MemoryDb`),切换为 PostgreSQL 仅需 3 步:
### 1. 添加依赖
取消 `Cargo.toml` 中的注释:
```toml
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres"] }
```
### 2. 实现 Db trait
新建 `src/db/postgres.rs`,对 `PgPool` 实现 `Db` trait。
### 3. 修改入口
`src/main.rs` 中将 `MemoryDb::new()` 替换为 `PgPool::connect(...).await`
## 技术栈
- [Axum 0.8](https://crates.io/crates/axum) — Web 框架
- [Tokio](https://crates.io/crates/tokio) — 异步运行时
- [Serde](https://crates.io/crates/serde) — 序列化 / 反序列化
- [Chrono](https://crates.io/crates/chrono) — 日期时间处理
- [Dotenvy](https://crates.io/crates/dotenvy) — .env 文件加载
## License
MIT
+20
View File
@@ -0,0 +1,20 @@
-- 001_init.sql
-- Checkpoint 服务初始建表
CREATE TABLE IF NOT EXISTS checkpoints (
id BIGSERIAL PRIMARY KEY,
user_id VARCHAR(128) NOT NULL,
state VARCHAR(64) NOT NULL,
timestamp BIGINT NOT NULL,
content JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- 核心查询索引:按用户 + 时间范围
CREATE INDEX IF NOT EXISTS idx_checkpoints_user_ts
ON checkpoints (user_id, timestamp);
-- 快速获取用户最新检查点
CREATE INDEX IF NOT EXISTS idx_checkpoints_user_latest
ON checkpoints (user_id, timestamp DESC);
+84
View File
@@ -0,0 +1,84 @@
-- 002_daily_summary.sql
-- 状态时长定时快照表 + 聚合函数 + pg_cron 调度说明
-- ============================================================
-- 1. 快照存储表
-- ============================================================
CREATE TABLE IF NOT EXISTS state_summaries (
id BIGSERIAL PRIMARY KEY,
user_id VARCHAR(128) NOT NULL,
state VARCHAR(64) NOT NULL,
duration_secs BIGINT NOT NULL DEFAULT 0,
period_start BIGINT NOT NULL,
period_end BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
UNIQUE (user_id, state, period_start, period_end)
);
CREATE INDEX IF NOT EXISTS idx_summaries_user_period
ON state_summaries (user_id, period_start DESC);
-- ============================================================
-- 2. 聚合函数:滑动窗口计算指定时间范围内各状态时长
-- - 调度器只需传入 period_start / period_end
-- - 内部用 LEAD() 窗口函数求相邻检查点差值
-- - ON CONFLICT 保证幂等(重复执行不产生重复行)
-- ============================================================
CREATE OR REPLACE FUNCTION aggregate_checkpoint_durations(
p_start BIGINT,
p_end BIGINT
) RETURNS BIGINT AS $$
DECLARE
affected BIGINT;
BEGIN
INSERT INTO state_summaries (user_id, state, duration_secs, period_start, period_end)
SELECT user_id, state, SUM(duration) AS duration_secs, p_start, p_end
FROM (
SELECT user_id, state,
LEAD(timestamp) OVER (PARTITION BY user_id ORDER BY timestamp) - timestamp AS duration
FROM checkpoints
WHERE timestamp >= p_start AND timestamp <= p_end
) sub
WHERE duration > 0
GROUP BY user_id, state
ON CONFLICT (user_id, state, period_start, period_end) DO UPDATE
SET duration_secs = EXCLUDED.duration_secs;
GET DIAGNOSTICS affected = ROW_COUNT;
RETURN affected;
END;
$$ LANGUAGE plpgsql;
-- ============================================================
-- 3. pg_cron 调度(需要超级用户手动执行一次)
-- ============================================================
--
-- -- 安装扩展(仅需一次)
-- CREATE EXTENSION IF NOT EXISTS pg_cron;
--
-- -- 每 5 分钟执行一次聚合
-- SELECT cron.schedule(
-- 'aggregate-5min',
-- '*/5 * * * *',
-- $$ SELECT aggregate_checkpoint_durations(
-- FLOOR(EXTRACT(EPOCH FROM now() - INTERVAL '5 minutes'))::BIGINT,
-- FLOOR(EXTRACT(EPOCH FROM now()))::BIGINT
-- ); $$
-- );
--
-- -- 每小时整点执行一次(日报表用)
-- SELECT cron.schedule(
-- 'aggregate-hourly',
-- '0 * * * *',
-- $$ SELECT aggregate_checkpoint_durations(
-- FLOOR(EXTRACT(EPOCH FROM now() - INTERVAL '1 hour'))::BIGINT,
-- FLOOR(EXTRACT(EPOCH FROM now()))::BIGINT
-- ); $$
-- );
--
-- -- 查看 cron 任务状态
-- SELECT * FROM cron.job;
--
-- -- 取消(如需)
-- SELECT cron.unschedule('aggregate-5min');
+65
View File
@@ -0,0 +1,65 @@
-- 003_sessions.sql
-- 用户会话表 + 心跳管理 + 离线检测 + pg_cron 调度
-- ============================================================
-- 1. 用户会话表(跟踪当前状态和心跳)
-- ============================================================
CREATE TABLE IF NOT EXISTS user_sessions (
user_id VARCHAR(128) PRIMARY KEY,
current_state VARCHAR(64) NOT NULL DEFAULT 'Offline',
last_heartbeat BIGINT NOT NULL,
last_state_change BIGINT NOT NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- ============================================================
-- 2. 离线检测函数:心跳超时 60s 的用户自动补一条 Offline 检查点
-- pg_cron 每 1 分钟执行一次
-- ============================================================
CREATE OR REPLACE FUNCTION detect_offline_users(
timeout_secs BIGINT DEFAULT 60
) RETURNS SETOF BIGINT AS $$
DECLARE
r RECORD;
now_ts BIGINT;
BEGIN
now_ts := EXTRACT(EPOCH FROM now())::BIGINT;
FOR r IN
SELECT user_id, last_heartbeat
FROM user_sessions
WHERE current_state != 'Offline'
AND (now_ts - last_heartbeat) > timeout_secs
LOOP
-- 插入离线检查点(时间戳 = 最后心跳 + 超时时间)
INSERT INTO checkpoints (user_id, state, timestamp)
VALUES (r.user_id, 'Offline', r.last_heartbeat + timeout_secs);
-- 更新会话状态
UPDATE user_sessions
SET current_state = 'Offline',
last_state_change = r.last_heartbeat + timeout_secs,
updated_at = now()
WHERE user_id = r.user_id;
RETURN NEXT 1;
END LOOP;
END;
$$ LANGUAGE plpgsql;
-- ============================================================
-- 3. pg_cron 调度(需要超级用户手动执行一次)
-- ============================================================
--
-- -- 安装扩展(仅需一次)
-- CREATE EXTENSION IF NOT EXISTS pg_cron;
--
-- -- 每 1 分钟检测离线用户
-- SELECT cron.schedule(
-- 'detect-offline',
-- '* * * * *',
-- $$ SELECT detect_offline_users(60); $$
-- );
--
-- -- 查看状态
-- SELECT * FROM cron.job;
+33
View File
@@ -0,0 +1,33 @@
use std::net::SocketAddr;
/// 服务器配置
#[derive(Debug, Clone)]
pub struct Config {
/// 监听地址
pub listen_addr: SocketAddr,
/// 数据库连接字符串(接入真实数据库时使用)
#[allow(dead_code)]
pub database_url: Option<String>,
}
impl Config {
/// 从环境变量加载配置,缺失时使用默认值
pub fn from_env() -> Self {
Self {
listen_addr: std::env::var("LISTEN_ADDR")
.unwrap_or_else(|_| "127.0.0.1:3000".into())
.parse()
.expect("invalid LISTEN_ADDR"),
database_url: std::env::var("DATABASE_URL").ok(),
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
listen_addr: SocketAddr::from(([127, 0, 0, 1], 3000)),
database_url: None,
}
}
}
+70
View File
@@ -0,0 +1,70 @@
pub mod postgres;
use std::sync::Arc;
use crate::models::checkpoint::{Checkpoint, StateSummary, UserStatusResponse};
/// 心跳返回信息
pub struct HeartbeatInfo {
pub current_state: String,
pub last_heartbeat: i64,
}
/// 数据库操作抽象 trait —— 所有持久化实现必须满足此接口
#[async_trait::async_trait]
pub trait Db: Send + Sync + 'static {
/// 创建一个检查点记录(仅在状态变更时调用)
async fn create_checkpoint(
&self,
user_id: &str,
state: &str,
timestamp: i64,
content: Option<serde_json::Value>,
) -> Result<Checkpoint, String>;
/// 获取单个检查点
async fn get_checkpoint(&self, id: u64) -> Result<Option<Checkpoint>, String>;
/// 列出用户的检查点(支持时间范围 + 条数限制)
async fn list_checkpoints(
&self,
user_id: &str,
from: Option<i64>,
to: Option<i64>,
limit: Option<usize>,
) -> Result<Vec<Checkpoint>, String>;
/// 获取用户最新的检查点(预留,用于心跳检测等场景)
#[allow(dead_code)]
async fn get_latest_checkpoint(&self, user_id: &str) -> Result<Option<Checkpoint>, String>;
/// 获取用户状态汇总(当前状态 + 各状态时长)
async fn get_status_summary(&self, user_id: &str) -> Result<UserStatusResponse, String>;
/// 查询用户的定时快照历史
async fn list_state_summaries(
&self,
user_id: &str,
from: Option<i64>,
to: Option<i64>,
limit: Option<usize>,
) -> Result<Vec<StateSummary>, String>;
/// 心跳:刷新 last_heartbeat,不产生检查点
/// state 为空字符串时只刷新心跳不改变状态
async fn heartbeat(&self, user_id: &str, state: &str) -> Result<HeartbeatInfo, String>;
/// [管理] 手动触发离线检测(开发/调试用,生产由 pg_cron 接管)
async fn admin_detect_offline(&self, timeout_secs: i64) -> Result<u64, String>;
/// [管理] 手动触发一次聚合(开发/调试用,生产由 pg_cron 接管)
async fn admin_trigger_aggregation(&self, period_start: i64, period_end: i64) -> Result<u64, String>;
/// [调试] 将用户 last_heartbeat 设为指定值(模拟超时)
async fn debug_set_last_heartbeat(&self, user_id: &str, ts: i64) -> Result<(), String>;
}
/// 简便构造:将 Db 实现包装为 Arc
pub fn into_shared(db: impl Db) -> Arc<dyn Db> {
Arc::new(db)
}
+306
View File
@@ -0,0 +1,306 @@
use chrono::Utc;
use serde_json::Value;
use sqlx::PgPool;
use crate::db::Db;
use crate::db::HeartbeatInfo;
use crate::models::checkpoint::{Checkpoint, StateDuration, StateSummary, StateType, UserStatusResponse};
/// PostgreSQL 数据库实现(通过 sqlx
pub struct PgDb {
pool: PgPool,
}
impl PgDb {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
fn parse_state(s: &str) -> StateType {
match s {
"Online" => StateType::Online,
"Offline" => StateType::Offline,
"Idle" => StateType::Idle,
"Working" => StateType::Working,
"Sleeping" => StateType::Sleeping,
other => StateType::Custom(other.to_string()),
}
}
fn row_to_checkpoint(r: PgCheckpointRow) -> Checkpoint {
Checkpoint {
id: r.id as u64,
user_id: r.user_id,
state: parse_state(&r.state),
timestamp: r.timestamp,
content: r.content,
}
}
#[async_trait::async_trait]
impl Db for PgDb {
async fn create_checkpoint(
&self,
user_id: &str,
state: &str,
timestamp: i64,
content: Option<Value>,
) -> Result<Checkpoint, String> {
let row: PgCheckpointRow = sqlx::query_as(
"INSERT INTO checkpoints (user_id, state, timestamp, content) \
VALUES ($1, $2, $3, $4) \
RETURNING id, user_id, state, timestamp, content",
)
.bind(user_id)
.bind(state)
.bind(timestamp)
.bind(content)
.fetch_one(&self.pool)
.await
.map_err(|e| format!("insert checkpoint: {e}"))?;
Ok(row_to_checkpoint(row))
}
async fn get_checkpoint(&self, id: u64) -> Result<Option<Checkpoint>, String> {
let row: Option<PgCheckpointRow> = sqlx::query_as(
"SELECT id, user_id, state, timestamp, content \
FROM checkpoints WHERE id = $1",
)
.bind(id as i64)
.fetch_optional(&self.pool)
.await
.map_err(|e| format!("get checkpoint: {e}"))?;
Ok(row.map(row_to_checkpoint))
}
async fn list_checkpoints(
&self,
user_id: &str,
from: Option<i64>,
to: Option<i64>,
limit: Option<usize>,
) -> Result<Vec<Checkpoint>, String> {
let rows: Vec<PgCheckpointRow> = sqlx::query_as(
"SELECT id, user_id, state, timestamp, content \
FROM checkpoints WHERE user_id = $1 ORDER BY timestamp",
)
.bind(user_id)
.fetch_all(&self.pool)
.await
.map_err(|e| format!("list checkpoints: {e}"))?;
let mut result: Vec<Checkpoint> = rows
.into_iter()
.map(row_to_checkpoint)
.filter(|c| from.map_or(true, |f| c.timestamp >= f))
.filter(|c| to.map_or(true, |t| c.timestamp <= t))
.collect();
if let Some(n) = limit {
result.truncate(n);
}
Ok(result)
}
async fn get_latest_checkpoint(&self, user_id: &str) -> Result<Option<Checkpoint>, String> {
let row: Option<PgCheckpointRow> = sqlx::query_as(
"SELECT id, user_id, state, timestamp, content \
FROM checkpoints \
WHERE user_id = $1 \
ORDER BY timestamp DESC \
LIMIT 1",
)
.bind(user_id)
.fetch_optional(&self.pool)
.await
.map_err(|e| format!("latest checkpoint: {e}"))?;
Ok(row.map(row_to_checkpoint))
}
async fn get_status_summary(&self, user_id: &str) -> Result<UserStatusResponse, String> {
// 用窗口函数计算相邻检查点之间的时长
let rows: Vec<StateDurationRow> = sqlx::query_as(
r#"SELECT state, duration FROM (
SELECT state,
LEAD(timestamp) OVER (PARTITION BY user_id ORDER BY timestamp) - timestamp AS duration
FROM checkpoints
WHERE user_id = $1
) sub
WHERE duration > 0"#,
)
.bind(user_id)
.fetch_all(&self.pool)
.await
.map_err(|e| format!("status summary: {e}"))?;
let mut durations_map: std::collections::BTreeMap<String, i64> =
std::collections::BTreeMap::new();
for row in &rows {
*durations_map.entry(row.state.clone()).or_insert(0) += row.duration;
}
let durations: Vec<StateDuration> = durations_map
.into_iter()
.map(|(k, v)| StateDuration {
state: parse_state(&k),
duration_secs: v,
})
.collect();
let latest = self.get_latest_checkpoint(user_id).await?;
let (current_state, since) = match latest {
Some(cp) => (cp.state, cp.timestamp),
None => (StateType::Offline, 0),
};
Ok(UserStatusResponse {
user_id: user_id.to_string(),
current_state,
since,
durations,
})
}
async fn list_state_summaries(
&self,
user_id: &str,
from: Option<i64>,
to: Option<i64>,
limit: Option<usize>,
) -> Result<Vec<StateSummary>, String> {
use crate::models::checkpoint::StateSummary;
let rows: Vec<SummaryRow> = sqlx::query_as(
"SELECT id, user_id, state, duration_secs, period_start, period_end, \
TO_CHAR(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at \
FROM state_summaries \
WHERE user_id = $1 \
AND ($2::BIGINT IS NULL OR period_start >= $2) \
AND ($3::BIGINT IS NULL OR period_end <= $3) \
ORDER BY period_start DESC \
LIMIT $4",
)
.bind(user_id)
.bind(from)
.bind(to)
.bind(limit.unwrap_or(100) as i64)
.fetch_all(&self.pool)
.await
.map_err(|e| format!("list summaries: {e}"))?;
Ok(rows
.into_iter()
.map(|r| StateSummary {
id: r.id as u64,
user_id: r.user_id,
state: parse_state(&r.state),
duration_secs: r.duration_secs,
period_start: r.period_start,
period_end: r.period_end,
created_at: r.created_at,
})
.collect())
}
async fn heartbeat(&self, user_id: &str, state: &str) -> Result<HeartbeatInfo, String> {
let now = Utc::now().timestamp();
let effective_state = if state.is_empty() { None } else { Some(state) };
let row: HeartbeatRow = sqlx::query_as(
r#"INSERT INTO user_sessions (user_id, current_state, last_heartbeat, last_state_change)
VALUES ($1, COALESCE($2, 'Offline'), $3, $3)
ON CONFLICT (user_id) DO UPDATE SET
current_state = COALESCE($2, user_sessions.current_state),
last_heartbeat = $3,
last_state_change = CASE WHEN $2 IS NOT NULL THEN $3 ELSE user_sessions.last_state_change END,
updated_at = now()
RETURNING current_state, last_heartbeat"#,
)
.bind(user_id)
.bind(effective_state)
.bind(now)
.fetch_one(&self.pool)
.await
.map_err(|e| format!("heartbeat: {e}"))?;
Ok(HeartbeatInfo {
current_state: row.current_state,
last_heartbeat: row.last_heartbeat,
})
}
async fn admin_detect_offline(&self, timeout_secs: i64) -> Result<u64, String> {
let row: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM detect_offline_users($1)"
)
.bind(timeout_secs)
.fetch_one(&self.pool)
.await
.map_err(|e| format!("detect offline: {e}"))?;
Ok(row.0 as u64)
}
async fn admin_trigger_aggregation(&self, period_start: i64, period_end: i64) -> Result<u64, String> {
let row: (i64,) = sqlx::query_as(
"SELECT aggregate_checkpoint_durations($1, $2)"
)
.bind(period_start)
.bind(period_end)
.fetch_one(&self.pool)
.await
.map_err(|e| format!("aggregation: {e}"))?;
Ok(row.0 as u64)
}
async fn debug_set_last_heartbeat(&self, user_id: &str, ts: i64) -> Result<(), String> {
sqlx::query(
"UPDATE user_sessions SET last_heartbeat = $1 WHERE user_id = $2"
)
.bind(ts)
.bind(user_id)
.execute(&self.pool)
.await
.map_err(|e| format!("debug set heartbeat: {e}"))?;
Ok(())
}
}
#[derive(Debug, sqlx::FromRow)]
struct HeartbeatRow {
current_state: String,
last_heartbeat: i64,
}
// ---------- sqlx 行映射 ----------
#[derive(Debug, sqlx::FromRow)]
struct PgCheckpointRow {
id: i64,
user_id: String,
state: String,
timestamp: i64,
content: Option<Value>,
}
#[derive(Debug, sqlx::FromRow)]
struct StateDurationRow {
state: String,
duration: i64,
}
#[derive(Debug, sqlx::FromRow)]
struct SummaryRow {
id: i64,
user_id: String,
state: String,
duration_secs: i64,
period_start: i64,
period_end: i64,
created_at: String,
}
+54
View File
@@ -0,0 +1,54 @@
use axum::{http::StatusCode, response::IntoResponse, Json};
use serde::Serialize;
/// 统一错误响应体
#[derive(Debug, Serialize)]
pub struct ErrorBody {
pub error: String,
}
/// 应用层统一错误类型
#[derive(Debug)]
pub enum AppError {
NotFound(String),
#[allow(dead_code)]
BadRequest(String),
Internal(String),
}
impl AppError {
pub fn status_code(&self) -> StatusCode {
match self {
AppError::NotFound(_) => StatusCode::NOT_FOUND,
AppError::BadRequest(_) => StatusCode::BAD_REQUEST,
AppError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
impl std::fmt::Display for AppError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AppError::NotFound(msg) => write!(f, "Not Found: {msg}"),
AppError::BadRequest(msg) => write!(f, "Bad Request: {msg}"),
AppError::Internal(msg) => write!(f, "Internal Error: {msg}"),
}
}
}
impl IntoResponse for AppError {
fn into_response(self) -> axum::response::Response {
let status = self.status_code();
let body = Json(ErrorBody {
error: self.to_string(),
});
(status, body).into_response()
}
}
// 允许 Axum handler 直接使用 `?` 将 sqlx 错误转为 AppError
impl From<sqlx::Error> for AppError {
fn from(e: sqlx::Error) -> Self {
AppError::Internal(format!("database error: {e}"))
}
}
+77
View File
@@ -0,0 +1,77 @@
use axum::{
extract::{Path, State},
Json,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use crate::error::AppError;
use crate::state::AppState;
#[derive(Serialize)]
pub struct AdminResult {
pub affected: u64,
pub message: String,
}
#[derive(Deserialize)]
pub struct SetHeartbeatRequest {
pub seconds_ago: i64,
}
/// POST /admin/offline-check
/// 手动触发离线检测(开发/调试用)
pub async fn offline_check(
State(state): State<AppState>,
) -> Result<Json<AdminResult>, AppError> {
let n = state
.db
.admin_detect_offline(60)
.await
.map_err(AppError::Internal)?;
Ok(Json(AdminResult {
affected: n,
message: format!("{n} users marked as offline"),
}))
}
/// POST /admin/aggregate
/// 手动触发一次聚合(开发/调试用)
pub async fn aggregate_now(
State(state): State<AppState>,
) -> Result<Json<AdminResult>, AppError> {
let now = Utc::now().timestamp();
let start = now - 3600; // 聚合最近 1 小时
let n = state
.db
.admin_trigger_aggregation(start, now)
.await
.map_err(AppError::Internal)?;
Ok(Json(AdminResult {
affected: n,
message: format!("{n} summary rows upserted"),
}))
}
/// POST /admin/users/{user_id}/set-heartbeat
/// 调试用:将用户心跳时间设为 N 秒前(模拟超时)
pub async fn set_heartbeat_old(
State(state): State<AppState>,
Path(user_id): Path<String>,
Json(payload): Json<SetHeartbeatRequest>,
) -> Result<Json<AdminResult>, AppError> {
let ts = Utc::now().timestamp() - payload.seconds_ago;
state
.db
.debug_set_last_heartbeat(&user_id, ts)
.await
.map_err(AppError::Internal)?;
Ok(Json(AdminResult {
affected: 1,
message: format!("{user_id} last_heartbeat set to {}s ago", payload.seconds_ago),
}))
}
+122
View File
@@ -0,0 +1,122 @@
use axum::{
extract::{Path, Query, State},
http::StatusCode,
Json,
};
use chrono::Utc;
use crate::error::AppError;
use crate::models::checkpoint::{
Checkpoint, CreateCheckpointRequest, HeartbeatResponse, ListCheckpointsQuery,
ListSummariesQuery, StateSummary, UserStatusResponse,
};
use crate::state::AppState;
// ---------- POST /users/{user_id}/checkpoints ----------
// 仅在状态变更时调用,同步更新会话
pub async fn create_checkpoint(
State(state): State<AppState>,
Path(user_id): Path<String>,
Json(payload): Json<CreateCheckpointRequest>,
) -> Result<(StatusCode, Json<Checkpoint>), AppError> {
let timestamp = payload.timestamp.unwrap_or_else(|| Utc::now().timestamp());
let cp = state
.db
.create_checkpoint(&user_id, &payload.state.to_string(), timestamp, payload.content)
.await
.map_err(AppError::Internal)?;
// 状态变更时同步更新会话(心跳也顺带刷新)
let _ = state.db.heartbeat(&user_id, &payload.state.to_string()).await;
Ok((StatusCode::CREATED, Json(cp)))
}
// ---------- GET /users/{user_id}/checkpoints ----------
pub async fn list_checkpoints(
State(state): State<AppState>,
Path(user_id): Path<String>,
Query(query): Query<ListCheckpointsQuery>,
) -> Result<Json<Vec<Checkpoint>>, AppError> {
let cps = state
.db
.list_checkpoints(&user_id, query.from, query.to, query.limit)
.await
.map_err(AppError::Internal)?;
Ok(Json(cps))
}
// ---------- GET /users/{user_id}/checkpoints/{id} ----------
pub async fn get_checkpoint(
State(state): State<AppState>,
Path((user_id, id)): Path<(String, u64)>,
) -> Result<Json<Checkpoint>, AppError> {
let cp = state
.db
.get_checkpoint(id)
.await
.map_err(AppError::Internal)?
.ok_or_else(|| AppError::NotFound(format!("checkpoint {id} not found")))?;
if cp.user_id != user_id {
return Err(AppError::NotFound(format!("checkpoint {id} not found")));
}
Ok(Json(cp))
}
// ---------- GET /users/{user_id}/status ----------
pub async fn get_user_status(
State(state): State<AppState>,
Path(user_id): Path<String>,
) -> Result<Json<UserStatusResponse>, AppError> {
let summary = state
.db
.get_status_summary(&user_id)
.await
.map_err(AppError::Internal)?;
Ok(Json(summary))
}
// ---------- GET /users/{user_id}/summaries ----------
pub async fn list_summaries(
State(state): State<AppState>,
Path(user_id): Path<String>,
Query(query): Query<ListSummariesQuery>,
) -> Result<Json<Vec<StateSummary>>, AppError> {
let rows = state
.db
.list_state_summaries(&user_id, query.from, query.to, query.limit)
.await
.map_err(AppError::Internal)?;
Ok(Json(rows))
}
// ---------- POST /users/{user_id}/heartbeat ----------
// 心跳验证(每 30s),不产生检查点,仅刷新 last_heartbeat
pub async fn heartbeat(
State(state): State<AppState>,
Path(user_id): Path<String>,
) -> Result<Json<HeartbeatResponse>, AppError> {
let hb = state
.db
.heartbeat(&user_id, "")
.await
.map_err(AppError::Internal)?;
Ok(Json(HeartbeatResponse {
user_id,
current_state: hb.current_state,
last_heartbeat: hb.last_heartbeat,
}))
}
+6
View File
@@ -0,0 +1,6 @@
use axum::{response::IntoResponse, Json};
/// GET /health
pub async fn health_check() -> impl IntoResponse {
Json(serde_json::json!({ "status": "ok" }))
}
+3
View File
@@ -0,0 +1,3 @@
pub mod admin;
pub mod checkpoints;
pub mod health;
+78 -2
View File
@@ -1,3 +1,79 @@
fn main() {
println!("Hello, world!");
mod config;
mod db;
mod error;
mod handlers;
mod models;
mod router;
mod state;
use config::Config;
use db::postgres::PgDb;
use state::AppState;
/// 按 ; 拆分 SQL,跳过 $$...$$ 内的分号(保护 PG 函数体)
fn split_sql(sql: &str) -> Vec<String> {
let mut stmts = Vec::new();
let mut buf = String::new();
let mut dollar_depth = 0;
let chars: Vec<char> = sql.chars().collect();
let mut i = 0;
while i < chars.len() {
if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '$' {
if dollar_depth == 0 { dollar_depth = 1; } else { dollar_depth = 0; }
buf.push_str("$$");
i += 2;
} else if chars[i] == ';' && dollar_depth == 0 {
stmts.push(buf.trim().to_string());
buf.clear();
i += 1;
} else {
buf.push(chars[i]);
i += 1;
}
}
let remainder = buf.trim().to_string();
if !remainder.is_empty() {
stmts.push(remainder);
}
stmts
}
async fn run_migration(pool: &sqlx::PgPool, sql_file: &str, name: &str) {
for stmt in split_sql(sql_file) {
let trimmed = stmt.trim();
if !trimmed.is_empty() {
sqlx::query(trimmed)
.execute(pool)
.await
.unwrap_or_else(|e| panic!("Migration {name} failed: {e}\nSQL: {trimmed}"));
}
}
}
#[tokio::main]
async fn main() {
let _ = dotenvy::dotenv();
let cfg = Config::from_env();
let database_url = cfg.database_url
.expect("DATABASE_URL must be set (PostgreSQL is the only supported backend)");
println!("🔗 Connecting to PostgreSQL...");
let pool = sqlx::PgPool::connect(&database_url)
.await
.expect("Failed to connect to PostgreSQL");
run_migration(&pool, include_str!("../migrations/001_init.sql"), "001_init").await;
run_migration(&pool, include_str!("../migrations/002_daily_summary.sql"), "002_summary").await;
run_migration(&pool, include_str!("../migrations/003_sessions.sql"), "003_sessions").await;
println!("✅ PostgreSQL connected, migrations applied");
let db = db::into_shared(PgDb::new(pool));
let state = AppState::new(db);
let app = router::build(state);
println!("🚀 Server running at http://{}", cfg.listen_addr);
let listener = tokio::net::TcpListener::bind(cfg.listen_addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
+129
View File
@@ -0,0 +1,129 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// 用户状态枚举 —— 通过 `Custom(String)` 变体可自由扩充新状态
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateType {
Online,
Offline,
Idle,
Working,
Sleeping,
/// 可自由扩充的自定义状态,如 "Gaming", "Meeting", "Driving" 等
Custom(String),
}
// 自定义序列化:所有变体序列化为扁平字符串
impl Serialize for StateType {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.to_string())
}
}
// 自定义反序列化:已知名称匹配变体,未知名称视为 Custom
impl<'de> Deserialize<'de> for StateType {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
Ok(match s.as_str() {
"Online" => StateType::Online,
"Offline" => StateType::Offline,
"Idle" => StateType::Idle,
"Working" => StateType::Working,
"Sleeping" => StateType::Sleeping,
other => StateType::Custom(other.to_string()),
})
}
}
impl std::fmt::Display for StateType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StateType::Online => write!(f, "Online"),
StateType::Offline => write!(f, "Offline"),
StateType::Idle => write!(f, "Idle"),
StateType::Working => write!(f, "Working"),
StateType::Sleeping => write!(f, "Sleeping"),
StateType::Custom(s) => write!(f, "{s}"),
}
}
}
/// 检查点 —— 用户在某时刻的状态快照
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
pub id: u64,
pub user_id: String,
pub state: StateType,
/// Unix 时间戳(秒)
pub timestamp: i64,
/// 可扩展的 JSON 元数据(设备信息、坐标等)
pub content: Option<Value>,
}
/// 创建检查点的请求体
#[derive(Debug, Deserialize)]
pub struct CreateCheckpointRequest {
pub state: StateType,
/// 时间戳(可选),不传则服务端填充当前时间
pub timestamp: Option<i64>,
/// 可选的附加元数据
pub content: Option<Value>,
}
/// 单个状态的持续时间
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateDuration {
pub state: StateType,
/// 该状态累计时长(秒)
pub duration_secs: i64,
}
/// 用户状态汇总响应
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserStatusResponse {
pub user_id: String,
/// 当前所处状态
pub current_state: StateType,
/// 自何时进入当前状态(Unix 秒)
pub since: i64,
/// 各状态累计时长
pub durations: Vec<StateDuration>,
}
/// 查询检查点列表的参数
#[derive(Debug, Deserialize)]
pub struct ListCheckpointsQuery {
pub from: Option<i64>,
pub to: Option<i64>,
pub limit: Option<usize>,
}
// ---------- 定时快照模型 ----------
/// 状态时长快照(来自 state_summaries 表)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateSummary {
pub id: u64,
pub user_id: String,
pub state: StateType,
pub duration_secs: i64,
pub period_start: i64,
pub period_end: i64,
pub created_at: String,
}
/// 查询快照的参数
#[derive(Debug, Deserialize)]
pub struct ListSummariesQuery {
pub from: Option<i64>,
pub to: Option<i64>,
pub limit: Option<usize>,
}
/// 心跳响应
#[derive(Debug, Serialize)]
pub struct HeartbeatResponse {
pub user_id: String,
pub current_state: String,
pub last_heartbeat: i64,
}
+1
View File
@@ -0,0 +1 @@
pub mod checkpoint;
+32
View File
@@ -0,0 +1,32 @@
use axum::{routing::{get, post}, Router};
use crate::handlers::{admin, checkpoints, health};
use crate::state::AppState;
/// 组装所有路由
pub fn build(state: AppState) -> Router {
Router::new()
// 健康检查
.route("/health", get(health::health_check))
// 心跳验证(每 30s,不产生检查点)
.route("/users/{user_id}/heartbeat", post(checkpoints::heartbeat))
// 检查点: 列表 + 创建(状态变更时才调用)
.route(
"/users/{user_id}/checkpoints",
get(checkpoints::list_checkpoints).post(checkpoints::create_checkpoint),
)
// 单个检查点
.route(
"/users/{user_id}/checkpoints/{id}",
get(checkpoints::get_checkpoint),
)
// 用户状态汇总
.route("/users/{user_id}/status", get(checkpoints::get_user_status))
// 定时快照历史
.route("/users/{user_id}/summaries", get(checkpoints::list_summaries))
// 管理端点(开发/调试用)
.route("/admin/offline-check", post(admin::offline_check))
.route("/admin/aggregate", post(admin::aggregate_now))
.route("/admin/users/{user_id}/set-heartbeat", post(admin::set_heartbeat_old))
.with_state(state)
}
+16
View File
@@ -0,0 +1,16 @@
use std::sync::Arc;
use crate::db::Db;
/// 全局应用状态,所有 handler 通过 `State<AppState>` 共享
#[derive(Clone)]
pub struct AppState {
/// 数据库抽象层 —— 开发期用内存模拟,上线后替换为真实 DB
pub db: Arc<dyn Db>,
}
impl AppState {
pub fn new(db: Arc<dyn Db>) -> Self {
Self { db }
}
}