feat: Implement active workout flow with status management
- Added `status`, `total_sets`, and `total_volume` fields to the Workout model. - Introduced `source_kind`, `title_snapshot`, and `image_s3_url_snapshot` fields to the WorkoutItem model. - Created endpoints for managing active workouts, including finishing and discarding workouts. - Updated workout creation to ensure only one active workout exists per user. - Implemented batch addition of workout sets and updates to workout set details. - Enhanced database schema with Alembic migrations to support new fields and constraints. - Added validation to ensure at least one field is provided for workout set updates. - Updated calorie estimation logic to reflect new workout set structure.
This commit is contained in:
+285
-40
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException, Query, status
|
||||
from fastapi import Body, Depends, FastAPI, Header, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
@@ -19,11 +19,14 @@ from app.schemas import (
|
||||
ProgressionPoint,
|
||||
ProgressionRead,
|
||||
WorkoutCreate,
|
||||
WorkoutFinishRequest,
|
||||
WorkoutItemCreate,
|
||||
WorkoutItemRead,
|
||||
WorkoutRead,
|
||||
WorkoutSetBatchCreate,
|
||||
WorkoutSetCreate,
|
||||
WorkoutSetRead,
|
||||
WorkoutSetUpdate,
|
||||
WorkoutUpdate,
|
||||
)
|
||||
|
||||
@@ -127,6 +130,34 @@ def accessible_exercises(db: Session, user_id: uuid.UUID):
|
||||
)
|
||||
|
||||
|
||||
def load_workout(db: Session, workout_id: uuid.UUID, user_id: uuid.UUID) -> Workout:
|
||||
workout = db.scalar(
|
||||
select(Workout)
|
||||
.where(Workout.id == workout_id, Workout.user_id == user_id)
|
||||
.options(selectinload(Workout.items).selectinload(WorkoutItem.sets))
|
||||
)
|
||||
if not workout:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workout not found")
|
||||
return workout
|
||||
|
||||
|
||||
def get_active_workout_for_user(db: Session, user_id: uuid.UUID) -> Workout | None:
|
||||
return db.scalar(
|
||||
select(Workout)
|
||||
.where(Workout.user_id == user_id, Workout.status == "active")
|
||||
.options(selectinload(Workout.items).selectinload(WorkoutItem.sets))
|
||||
.order_by(Workout.started_at.desc())
|
||||
)
|
||||
|
||||
|
||||
def ensure_active_workout(workout: Workout) -> None:
|
||||
if workout.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Workout is not active",
|
||||
)
|
||||
|
||||
|
||||
@app.get(
|
||||
"/internal/catalog/equipment",
|
||||
dependencies=[InternalAuth],
|
||||
@@ -206,8 +237,14 @@ def list_workouts(db: Db, user_id: CurrentUserId) -> list[Workout]:
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
def create_workout(payload: WorkoutCreate, db: Db, user_id: CurrentUserId) -> Workout:
|
||||
if get_active_workout_for_user(db, user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Active workout already exists",
|
||||
)
|
||||
workout = Workout(
|
||||
user_id=user_id,
|
||||
status="active",
|
||||
started_at=payload.started_at or datetime.now(UTC),
|
||||
notes=payload.notes,
|
||||
)
|
||||
@@ -217,16 +254,18 @@ def create_workout(payload: WorkoutCreate, db: Db, user_id: CurrentUserId) -> Wo
|
||||
return workout
|
||||
|
||||
|
||||
@app.get(
|
||||
"/internal/workouts/active",
|
||||
dependencies=[InternalAuth],
|
||||
response_model=WorkoutRead | None,
|
||||
)
|
||||
def get_active_workout(db: Db, user_id: CurrentUserId) -> Workout | None:
|
||||
return get_active_workout_for_user(db, user_id)
|
||||
|
||||
|
||||
@app.get("/internal/workouts/{workout_id}", dependencies=[InternalAuth], response_model=WorkoutRead)
|
||||
def get_workout(workout_id: uuid.UUID, db: Db, user_id: CurrentUserId) -> Workout:
|
||||
workout = db.scalar(
|
||||
select(Workout)
|
||||
.where(Workout.id == workout_id, Workout.user_id == user_id)
|
||||
.options(selectinload(Workout.items).selectinload(WorkoutItem.sets))
|
||||
)
|
||||
if not workout:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workout not found")
|
||||
return workout
|
||||
return load_workout(db, workout_id, user_id)
|
||||
|
||||
|
||||
@app.patch(
|
||||
@@ -237,12 +276,52 @@ def get_workout(workout_id: uuid.UUID, db: Db, user_id: CurrentUserId) -> Workou
|
||||
def update_workout(
|
||||
workout_id: uuid.UUID, payload: WorkoutUpdate, db: Db, user_id: CurrentUserId
|
||||
) -> Workout:
|
||||
workout = get_workout(workout_id, db, user_id)
|
||||
workout = load_workout(db, workout_id, user_id)
|
||||
if payload.finished_at is not None:
|
||||
workout.finished_at = payload.finished_at
|
||||
workout.status = "finished"
|
||||
if payload.notes is not None:
|
||||
workout.notes = payload.notes
|
||||
recalculate_workout_calories(db, workout.id)
|
||||
recalculate_workout_totals(db, workout.id)
|
||||
db.commit()
|
||||
db.refresh(workout)
|
||||
return workout
|
||||
|
||||
|
||||
@app.post(
|
||||
"/internal/workouts/{workout_id}/finish",
|
||||
dependencies=[InternalAuth],
|
||||
response_model=WorkoutRead,
|
||||
)
|
||||
def finish_workout(
|
||||
workout_id: uuid.UUID,
|
||||
db: Db,
|
||||
user_id: CurrentUserId,
|
||||
payload: Annotated[WorkoutFinishRequest | None, Body()] = None,
|
||||
) -> Workout:
|
||||
workout = load_workout(db, workout_id, user_id)
|
||||
ensure_active_workout(workout)
|
||||
if payload and payload.notes is not None:
|
||||
workout.notes = payload.notes
|
||||
recalculate_workout_totals(db, workout.id)
|
||||
workout.finished_at = datetime.now(UTC)
|
||||
workout.status = "finished"
|
||||
db.commit()
|
||||
db.refresh(workout)
|
||||
return workout
|
||||
|
||||
|
||||
@app.post(
|
||||
"/internal/workouts/{workout_id}/discard",
|
||||
dependencies=[InternalAuth],
|
||||
response_model=WorkoutRead,
|
||||
)
|
||||
def discard_workout(workout_id: uuid.UUID, db: Db, user_id: CurrentUserId) -> Workout:
|
||||
workout = load_workout(db, workout_id, user_id)
|
||||
ensure_active_workout(workout)
|
||||
recalculate_workout_totals(db, workout.id)
|
||||
workout.finished_at = datetime.now(UTC)
|
||||
workout.status = "discarded"
|
||||
db.commit()
|
||||
db.refresh(workout)
|
||||
return workout
|
||||
@@ -257,21 +336,37 @@ def update_workout(
|
||||
def add_workout_item(
|
||||
workout_id: uuid.UUID, payload: WorkoutItemCreate, db: Db, user_id: CurrentUserId
|
||||
) -> WorkoutItem:
|
||||
workout = get_workout(workout_id, db, user_id)
|
||||
workout = load_workout(db, workout_id, user_id)
|
||||
ensure_active_workout(workout)
|
||||
source_kind: str
|
||||
title_snapshot: str
|
||||
image_s3_url_snapshot: str | None
|
||||
if payload.exercise_id:
|
||||
exercise = db.get(Exercise, payload.exercise_id)
|
||||
if not exercise or (not exercise.is_builtin and exercise.owner_user_id != user_id):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Exercise not found")
|
||||
if payload.equipment_id:
|
||||
source_kind = "exercise"
|
||||
title_snapshot = exercise.name
|
||||
image_s3_url_snapshot = exercise.image_s3_url
|
||||
elif payload.equipment_id:
|
||||
equipment = db.get(Equipment, payload.equipment_id)
|
||||
if not equipment or (not equipment.is_builtin and equipment.owner_user_id != user_id):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Equipment not found")
|
||||
source_kind = "equipment"
|
||||
title_snapshot = equipment.name
|
||||
image_s3_url_snapshot = equipment.image_s3_url
|
||||
|
||||
next_index = payload.order_index
|
||||
if next_index is None:
|
||||
next_index = len(workout.items)
|
||||
max_index = db.scalar(
|
||||
select(func.max(WorkoutItem.order_index)).where(WorkoutItem.workout_id == workout.id)
|
||||
)
|
||||
next_index = int(max_index or 0) + 1 if max_index is not None else 0
|
||||
item = WorkoutItem(
|
||||
workout_id=workout.id,
|
||||
source_kind=source_kind,
|
||||
title_snapshot=title_snapshot,
|
||||
image_s3_url_snapshot=image_s3_url_snapshot,
|
||||
**payload.model_dump(exclude={"order_index"}),
|
||||
order_index=next_index,
|
||||
)
|
||||
@@ -297,18 +392,30 @@ def add_workout_set(
|
||||
select(WorkoutItem)
|
||||
.join(Workout)
|
||||
.where(WorkoutItem.id == item_id, Workout.user_id == user_id)
|
||||
.options(selectinload(WorkoutItem.sets), selectinload(WorkoutItem.exercise))
|
||||
.options(
|
||||
selectinload(WorkoutItem.sets),
|
||||
selectinload(WorkoutItem.exercise),
|
||||
selectinload(WorkoutItem.workout),
|
||||
)
|
||||
)
|
||||
if not item:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workout item not found")
|
||||
ensure_active_workout(item.workout)
|
||||
|
||||
calories = payload.calories
|
||||
if calories is None:
|
||||
calories = estimate_set_calories(item, payload)
|
||||
calories = estimate_set_calories(
|
||||
item,
|
||||
payload.weight,
|
||||
payload.reps,
|
||||
payload.duration_seconds,
|
||||
payload.calories,
|
||||
)
|
||||
max_index = db.scalar(
|
||||
select(func.max(WorkoutSet.set_index)).where(WorkoutSet.workout_item_id == item.id)
|
||||
)
|
||||
|
||||
workout_set = WorkoutSet(
|
||||
workout_item_id=item.id,
|
||||
set_index=len(item.sets) + 1,
|
||||
set_index=int(max_index or 0) + 1,
|
||||
weight=payload.weight,
|
||||
reps=payload.reps,
|
||||
duration_seconds=payload.duration_seconds,
|
||||
@@ -317,7 +424,112 @@ def add_workout_set(
|
||||
)
|
||||
db.add(workout_set)
|
||||
db.flush()
|
||||
recalculate_workout_calories(db, item.workout_id)
|
||||
recalculate_workout_totals(db, item.workout_id)
|
||||
db.commit()
|
||||
db.refresh(workout_set)
|
||||
return workout_set
|
||||
|
||||
|
||||
@app.post(
|
||||
"/internal/workout-items/{item_id}/sets/batch",
|
||||
dependencies=[InternalAuth],
|
||||
response_model=list[WorkoutSetRead],
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
def add_workout_sets_batch(
|
||||
item_id: uuid.UUID,
|
||||
payload: WorkoutSetBatchCreate,
|
||||
db: Db,
|
||||
user_id: CurrentUserId,
|
||||
) -> list[WorkoutSet]:
|
||||
item = db.scalar(
|
||||
select(WorkoutItem)
|
||||
.join(Workout)
|
||||
.where(WorkoutItem.id == item_id, Workout.user_id == user_id)
|
||||
.options(selectinload(WorkoutItem.exercise), selectinload(WorkoutItem.workout))
|
||||
)
|
||||
if not item:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workout item not found")
|
||||
ensure_active_workout(item.workout)
|
||||
|
||||
max_index = db.scalar(
|
||||
select(func.max(WorkoutSet.set_index)).where(WorkoutSet.workout_item_id == item.id)
|
||||
)
|
||||
next_index = int(max_index or 0) + 1
|
||||
workout_sets: list[WorkoutSet] = []
|
||||
for set_payload in payload.sets:
|
||||
workout_set = WorkoutSet(
|
||||
workout_item_id=item.id,
|
||||
set_index=next_index,
|
||||
weight=set_payload.weight,
|
||||
reps=set_payload.reps,
|
||||
duration_seconds=set_payload.duration_seconds,
|
||||
calories=estimate_set_calories(
|
||||
item,
|
||||
set_payload.weight,
|
||||
set_payload.reps,
|
||||
set_payload.duration_seconds,
|
||||
set_payload.calories,
|
||||
),
|
||||
completed_at=set_payload.completed_at or datetime.now(UTC),
|
||||
)
|
||||
db.add(workout_set)
|
||||
workout_sets.append(workout_set)
|
||||
next_index += 1
|
||||
|
||||
db.flush()
|
||||
recalculate_workout_totals(db, item.workout_id)
|
||||
db.commit()
|
||||
for workout_set in workout_sets:
|
||||
db.refresh(workout_set)
|
||||
return workout_sets
|
||||
|
||||
|
||||
@app.patch(
|
||||
"/internal/workout-sets/{set_id}",
|
||||
dependencies=[InternalAuth],
|
||||
response_model=WorkoutSetRead,
|
||||
)
|
||||
def update_workout_set(
|
||||
set_id: uuid.UUID,
|
||||
payload: WorkoutSetUpdate,
|
||||
db: Db,
|
||||
user_id: CurrentUserId,
|
||||
) -> WorkoutSet:
|
||||
workout_set = db.scalar(
|
||||
select(WorkoutSet)
|
||||
.join(WorkoutItem)
|
||||
.join(Workout)
|
||||
.where(WorkoutSet.id == set_id, Workout.user_id == user_id)
|
||||
.options(
|
||||
selectinload(WorkoutSet.workout_item).selectinload(WorkoutItem.exercise),
|
||||
selectinload(WorkoutSet.workout_item).selectinload(WorkoutItem.workout),
|
||||
)
|
||||
)
|
||||
if not workout_set:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Set not found")
|
||||
ensure_active_workout(workout_set.workout_item.workout)
|
||||
|
||||
if payload.weight is not None:
|
||||
workout_set.weight = payload.weight
|
||||
if payload.reps is not None:
|
||||
workout_set.reps = payload.reps
|
||||
if "duration_seconds" in payload.model_fields_set:
|
||||
workout_set.duration_seconds = payload.duration_seconds
|
||||
if "completed_at" in payload.model_fields_set and payload.completed_at is not None:
|
||||
workout_set.completed_at = payload.completed_at
|
||||
if "calories" in payload.model_fields_set:
|
||||
workout_set.calories = payload.calories
|
||||
else:
|
||||
workout_set.calories = estimate_set_calories(
|
||||
workout_set.workout_item,
|
||||
float(workout_set.weight or 0),
|
||||
int(workout_set.reps or 0),
|
||||
workout_set.duration_seconds,
|
||||
None,
|
||||
)
|
||||
|
||||
recalculate_workout_totals(db, workout_set.workout_item.workout_id)
|
||||
db.commit()
|
||||
db.refresh(workout_set)
|
||||
return workout_set
|
||||
@@ -337,12 +549,12 @@ def delete_workout_item(item_id: uuid.UUID, db: Db, user_id: CurrentUserId) -> N
|
||||
if not item:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workout item not found")
|
||||
workout = db.get(Workout, item.workout_id)
|
||||
if workout and workout.finished_at:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Workout already finished")
|
||||
if workout:
|
||||
ensure_active_workout(workout)
|
||||
workout_id = item.workout_id
|
||||
db.delete(item)
|
||||
db.flush()
|
||||
recalculate_workout_calories(db, workout_id)
|
||||
recalculate_workout_totals(db, workout_id)
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -369,33 +581,64 @@ def delete_workout_set(
|
||||
workout = db.scalar(
|
||||
select(Workout).join(WorkoutItem).where(WorkoutItem.id == item_id)
|
||||
)
|
||||
if workout and workout.finished_at:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Workout already finished")
|
||||
if workout:
|
||||
ensure_active_workout(workout)
|
||||
db.delete(ws)
|
||||
db.flush()
|
||||
reindex_workout_item_sets(db, item_id)
|
||||
if workout:
|
||||
recalculate_workout_calories(db, workout.id)
|
||||
recalculate_workout_totals(db, workout.id)
|
||||
db.commit()
|
||||
|
||||
|
||||
def estimate_set_calories(item: WorkoutItem, payload: WorkoutSetCreate) -> float:
|
||||
if item.exercise and item.exercise.default_calories_per_minute and payload.duration_seconds:
|
||||
def reindex_workout_item_sets(db: Session, item_id: uuid.UUID) -> None:
|
||||
remaining_sets = list(
|
||||
db.scalars(
|
||||
select(WorkoutSet)
|
||||
.where(WorkoutSet.workout_item_id == item_id)
|
||||
.order_by(WorkoutSet.set_index.asc(), WorkoutSet.completed_at.asc())
|
||||
)
|
||||
)
|
||||
for index, workout_set in enumerate(remaining_sets, start=1):
|
||||
workout_set.set_index = index
|
||||
|
||||
|
||||
def estimate_set_calories(
|
||||
item: WorkoutItem,
|
||||
weight: float,
|
||||
reps: int,
|
||||
duration_seconds: int | None,
|
||||
calories: float | None,
|
||||
) -> float:
|
||||
if calories is not None:
|
||||
return calories
|
||||
if item.exercise and item.exercise.default_calories_per_minute and duration_seconds:
|
||||
return round(
|
||||
float(item.exercise.default_calories_per_minute) * payload.duration_seconds / 60,
|
||||
float(item.exercise.default_calories_per_minute) * duration_seconds / 60,
|
||||
2,
|
||||
)
|
||||
return round((payload.weight * max(payload.reps, 1)) / 120, 2)
|
||||
return round((weight * max(reps, 1)) / 120, 2)
|
||||
|
||||
|
||||
def recalculate_workout_totals(db: Session, workout_id: uuid.UUID) -> None:
|
||||
total_sets, total_volume, estimated_calories = db.execute(
|
||||
select(
|
||||
func.count(WorkoutSet.id),
|
||||
func.coalesce(func.sum(WorkoutSet.weight * WorkoutSet.reps), 0),
|
||||
func.coalesce(func.sum(WorkoutSet.calories), 0),
|
||||
)
|
||||
.join(WorkoutItem, WorkoutSet.workout_item_id == WorkoutItem.id)
|
||||
.where(WorkoutItem.workout_id == workout_id)
|
||||
).one()
|
||||
workout = db.get(Workout, workout_id)
|
||||
if workout:
|
||||
workout.total_sets = int(total_sets or 0)
|
||||
workout.total_volume = float(total_volume or 0)
|
||||
workout.estimated_calories = float(estimated_calories or 0)
|
||||
|
||||
|
||||
def recalculate_workout_calories(db: Session, workout_id: uuid.UUID) -> None:
|
||||
total = db.scalar(
|
||||
select(func.coalesce(func.sum(WorkoutSet.calories), 0))
|
||||
.join(WorkoutItem, WorkoutSet.workout_item_id == WorkoutItem.id)
|
||||
.where(WorkoutItem.workout_id == workout_id)
|
||||
)
|
||||
workout = db.get(Workout, workout_id)
|
||||
if workout:
|
||||
workout.estimated_calories = float(total or 0)
|
||||
recalculate_workout_totals(db, workout_id)
|
||||
|
||||
|
||||
@app.get(
|
||||
@@ -413,7 +656,7 @@ def get_progression(
|
||||
select(Workout.started_at, WorkoutSet.weight, WorkoutSet.reps)
|
||||
.join(WorkoutItem, WorkoutSet.workout_item_id == WorkoutItem.id)
|
||||
.join(Workout, WorkoutItem.workout_id == Workout.id)
|
||||
.where(Workout.user_id == user_id)
|
||||
.where(Workout.user_id == user_id, Workout.status != "discarded")
|
||||
.order_by(Workout.started_at.asc(), WorkoutSet.completed_at.asc())
|
||||
)
|
||||
if entity_id and kind == "exercise":
|
||||
@@ -450,7 +693,9 @@ def get_progression(
|
||||
def get_calories(db: Db, user_id: CurrentUserId) -> CaloriesRead:
|
||||
workouts = list(
|
||||
db.scalars(
|
||||
select(Workout).where(Workout.user_id == user_id).order_by(Workout.started_at.desc())
|
||||
select(Workout)
|
||||
.where(Workout.user_id == user_id, Workout.status != "discarded")
|
||||
.order_by(Workout.started_at.desc())
|
||||
)
|
||||
)
|
||||
total = sum(float(workout.estimated_calories or 0) for workout in workouts)
|
||||
|
||||
Reference in New Issue
Block a user