Decryptor: Require end of the output span to be indicated explicitly.

Usage becomes much less error-prone.
This commit is contained in:
hashlag
2026-02-02 01:10:22 +03:00
parent 67709e5361
commit a0f17ea127
3 changed files with 21 additions and 47 deletions

View File

@@ -9,9 +9,10 @@ class Decryptor
{
public:
template<typename OutputIt, typename InputIt>
void DecryptBlock(OutputIt out, InputIt inBegin, InputIt inEnd) const
void DecryptBlock(OutputIt outBegin, OutputIt outEnd,
InputIt inBegin, InputIt inEnd) const
{
Impl().DecryptBlock(out, inBegin, inEnd);
Impl().DecryptBlock(outBegin, outEnd, inBegin, inEnd);
}
template<typename Block>

View File

@@ -290,7 +290,8 @@ public:
{ }
template<typename OutputIt, typename InputIt>
void DecryptBlock(OutputIt out, InputIt inBegin, InputIt inEnd) const
void DecryptBlock(OutputIt outBegin, OutputIt outEnd,
InputIt inBegin, InputIt inEnd) const
{
RawBlockArray block;
@@ -305,7 +306,7 @@ public:
block.End()),
Schedule_);
Inner_::Bitwise::CrunchUInt64(out, decrypted);
Inner_::Bitwise::CrunchUInt64(outBegin, outEnd, decrypted);
}
Block DecryptBlock(Block block) const

View File

@@ -211,7 +211,7 @@ TEST(DesCryptTests, DecryptTest)
DesCrypt::Key desKey(key.begin(), key.end());
DesCrypt::DesDecryptor dec(desKey);
dec.DecryptBlock(result.begin(), data.begin(), data.end());
dec.DecryptBlock(result.begin(), result.end(), data.begin(), data.end());
return result;
}
@@ -303,7 +303,7 @@ TEST(DesCryptTests, DecryptShortDataTest)
DesCrypt::Key desKey(key.begin(), key.end());
DesCrypt::DesDecryptor dec(desKey);
dec.DecryptBlock(result.begin(), data.begin(), data.end());
dec.DecryptBlock(result.begin(), result.end(), data.begin(), data.end());
return result;
}
@@ -337,7 +337,7 @@ TEST(DesCryptTests, DecryptLongDataTest)
DesCrypt::Key desKey(key.begin(), key.end());
DesCrypt::DesDecryptor dec(desKey);
dec.DecryptBlock(result.begin(), data.begin(), data.end());
dec.DecryptBlock(result.begin(), result.end(), data.begin(), data.end());
return result;
}
@@ -410,62 +410,34 @@ TEST(DesCryptTests, OutIteratorUsageEncryptTest)
TEST(DesCryptTests, OutIteratorUsageDecryptTest)
{
struct OutputItMock
{
OutputItMock(size_t & asteriskCalls, size_t & incrementCalls)
: AsteriskCalls_(asteriskCalls)
, IncrementCalls_(incrementCalls)
{ }
uint8_t & operator*()
{
++AsteriskCalls_;
static uint8_t dummy = 0;
return dummy;
}
OutputItMock operator++(int)
{
++IncrementCalls_;
return *this;
}
size_t & AsteriskCalls_;
size_t & IncrementCalls_;
};
{
std::array<uint8_t, DesCrypt::BlockSize> data = { 0xe5, 0x1a, 0x9f, 0xd4, 0x19, 0xa7, 0x93, 0x44 };
std::array<uint8_t, 8> key = { 0xda, 0xec, 0x68, 0xae, 0x83, 0xe0, 0x1e, 0xab };
size_t asteriskCalls = 0;
size_t incrementCalls = 0;
OutputItMock it(asteriskCalls, incrementCalls);
std::array<uint8_t, 8> fact = {};
// Last 3 bytes should be untouched.
std::array<uint8_t, 8> expected = { 0x45, 0x69, 0x71, 0x17, 0x13, 0x00, 0x00, 0x00 };
DesCrypt::Key desKey(key.begin(), key.end());
DesCrypt::DesDecryptor dec(desKey);
dec.DecryptBlock(it, data.begin(), data.end());
dec.DecryptBlock(fact.begin(), fact.end() - 3, data.begin(), data.end());
ASSERT_EQ(8, asteriskCalls);
ASSERT_EQ(8, incrementCalls);
ASSERT_EQ(expected, fact);
}
{
std::array<uint8_t, 11> data = { 0xe5, 0x1a, 0x9f, 0xd4, 0x19, 0x9f, 0x9f, 0x9f, 0x9f, 0x9f, 0x9f };
std::array<uint8_t, DesCrypt::BlockSize + 2> data = { 0xe5, 0x1a, 0x9f, 0xd4, 0x19, 0xa7, 0x93, 0x44, 0x44, 0x44 };
std::array<uint8_t, 8> key = { 0xda, 0xec, 0x68, 0xae, 0x83, 0xe0, 0x1e, 0xab };
size_t asteriskCalls = 0;
size_t incrementCalls = 0;
OutputItMock it(asteriskCalls, incrementCalls);
std::array<uint8_t, 8> fact = {};
// Last 4 bytes should be untouched.
std::array<uint8_t, 8> expected = { 0x45, 0x69, 0x71, 0x17, 0x00, 0x00, 0x00, 0x00 };
DesCrypt::Key desKey(key.begin(), key.end());
DesCrypt::DesDecryptor dec(desKey);
dec.DecryptBlock(it, data.begin(), data.end());
dec.DecryptBlock(fact.begin(), fact.end() - 4, data.begin(), data.end());
ASSERT_EQ(8, asteriskCalls);
ASSERT_EQ(8, incrementCalls);
ASSERT_EQ(expected, fact);
}
}
@@ -519,7 +491,7 @@ static std::vector<uint8_t> DecryptThroughBase(const Decryptor<Impl> & dec,
std::vector<uint8_t> result;
result.resize(dec.GetBlockSize(), 0);
dec.DecryptBlock(result.begin(), begin, end);
dec.DecryptBlock(result.begin(), result.end(), begin, end);
return result;
}