wings/internal/ufs/quota_writer.go

118 lines
2.6 KiB
Go
Raw Permalink Normal View History

// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: Copyright (c) 2024 Matthew Penner
package ufs
import (
"errors"
"io"
"sync/atomic"
)
// CountedWriter is a writer that counts the amount of data written to the
// underlying writer.
type CountedWriter struct {
File
counter atomic.Int64
err error
}
// NewCountedWriter returns a new countedWriter that counts the amount of bytes
// written to the underlying writer.
func NewCountedWriter(f File) *CountedWriter {
return &CountedWriter{File: f}
}
// BytesWritten returns the amount of bytes that have been written to the
// underlying writer.
func (w *CountedWriter) BytesWritten() int64 {
return w.counter.Load()
}
// Error returns the error from the writer if any. If the error is an EOF, nil
// will be returned.
func (w *CountedWriter) Error() error {
if errors.Is(w.err, io.EOF) {
return nil
}
return w.err
}
// Write writes bytes to the underlying writer while tracking the total amount
// of bytes written.
func (w *CountedWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, io.EOF
}
// Write is a very simple operation for us to handle.
n, err := w.File.Write(p)
w.counter.Add(int64(n))
w.err = err
// TODO: is this how we actually want to handle errors with this?
if err == io.EOF {
return n, io.EOF
} else {
return n, nil
}
}
func (w *CountedWriter) ReadFrom(r io.Reader) (n int64, err error) {
cr := NewCountedReader(r)
n, err = w.File.ReadFrom(cr)
w.counter.Add(n)
return
}
// CountedReader is a reader that counts the amount of data read from the
// underlying reader.
type CountedReader struct {
reader io.Reader
counter atomic.Int64
err error
}
var _ io.Reader = (*CountedReader)(nil)
// NewCountedReader returns a new countedReader that counts the amount of bytes
// read from the underlying reader.
func NewCountedReader(r io.Reader) *CountedReader {
return &CountedReader{reader: r}
}
// BytesRead returns the amount of bytes that have been read from the underlying
// reader.
func (r *CountedReader) BytesRead() int64 {
return r.counter.Load()
}
// Error returns the error from the reader if any. If the error is an EOF, nil
// will be returned.
func (r *CountedReader) Error() error {
if errors.Is(r.err, io.EOF) {
return nil
}
return r.err
}
// Read reads bytes from the underlying reader while tracking the total amount
// of bytes read.
func (r *CountedReader) Read(p []byte) (int, error) {
if r.err != nil {
return 0, io.EOF
}
n, err := r.reader.Read(p)
r.counter.Add(int64(n))
r.err = err
if err == io.EOF {
return n, io.EOF
} else {
return n, nil
}
}