#include <iostream>
#include "misc.h"
#include "sudoku.h"

using namespace std;

/**
Prints a string representation of a unit type.  Used with verbose output.

Known types are written as "row", "col" or "box"; any other value is written
as its underlying integer.
*/
ostream &operator<<(ostream &o, Sudoku::Unit::Type t)
{
	switch(t)
	{
	case Sudoku::Unit::ROW:
		o<<"row";
		break;
	case Sudoku::Unit::COL:
		o<<"col";
		break;
	case Sudoku::Unit::BOX:
		o<<"box";
		break;
	default:
		/* Streaming t itself would select this very overload again (exact
		match beats promotion to int) and recurse forever, so print the
		underlying integer instead. */
		o<<static_cast<int>(t);
		break;
	}

	return o;
}

/**
Builds an empty puzzle: 81 unset cells wired into 9 rows, 9 columns and 9
boxes.  Units are numbered 1-9 within each kind, cells carry 1-based x,y
coordinates and are stored row by row.
*/
Sudoku::Sudoku():
	verbose(false),
	solved(false)
{
	// The 27 units
	for(unsigned n=1; n<=9; ++n)
	{
		rows.push_back(Unit(*this, Unit::ROW, n));
		cols.push_back(Unit(*this, Unit::COL, n));
		boxes.push_back(Unit(*this, Unit::BOX, n));
	}

	/* All cells have to exist before any unit stores pointers to them; growing
	the vector afterwards could move the cells and leave those pointers
	dangling. */
	for(unsigned idx=0; idx<81; ++idx)
		cells.push_back(Cell(idx%9+1, idx/9+1));

	// Hook every cell up to the row, column and box containing it
	for(unsigned idx=0; idx<81; ++idx)
	{
		const unsigned cx=idx%9;
		const unsigned cy=idx/9;
		cells[idx].set_units(rows[cy], cols[cx], boxes[cy/3*3+cx/3]);
	}

	// Flat list of all units, interleaved row/col/box for each index
	for(unsigned n=0; n<9; ++n)
	{
		units.push_back(&rows[n]);
		units.push_back(&cols[n]);
		units.push_back(&boxes[n]);
	}
}

/**
Sets the cell at the given coordinates to the given value.
*/
void Sudoku::set_cell(unsigned x, unsigned y, Number n)
{
	if(x>=9 && y>=9) return;
	if(n>9) return;

	Cell	&c=cell(x,y);
	c=n;
	if(!n)
	{
		c.set_fixed(false);
		clear_soft_masks();
	}
}

/**
Returns a pointer to the cell at the given coordinates, or 0 if either
coordinate is out of range (both must be below 9).
*/
const Sudoku::Cell *Sudoku::get_cell(unsigned x, unsigned y) const
{
	// Reject the call if either coordinate is off the grid
	if(x>=9 || y>=9) return 0;

	return &cell(x,y);
}

/**
Tries to solve the puzzle.  Each round first fills in every cell that has a
single candidate left.  If that places nothing, candidate elimination is tried.
If that also fails, guessing (brute force) is tried, but only once plain logic
has filled at least one cell during this call.  Rounds repeat as long as
progress is made.  Afterwards the solved flag tells whether every cell holds a
number.

@param   step  If true, the solver loop is only iterated once.  Multiple cells
               may still get solved.
*/
void Sudoku::solve(bool step)
{
	bool may_guess=false;
	for(;;)
	{
		// A contradictory grid can't be completed; don't waste effort on it
		if(detect_impossibilities())
			break;

		// Fill in every cell that has exactly one candidate left
		bool progress=false;
		for(unsigned i=0; i<9*9; ++i)
			progress=solve_cell(cells[i]) || progress;

		if(progress)
			may_guess=true;
		else if(eliminate())
			progress=true;
		else if(may_guess && brute_force())
			progress=true;

		// Stuck, or only a single step was requested
		if(!progress || step)
			break;
	}

	// Solved means no cell is left empty
	solved=true;
	for(unsigned i=0; i<9*9 && solved; ++i)
		if(!cells[i].get_num())
			solved=false;
}

/**
Attempts to eliminate candidates by employing various algorithms, in order:
unique placements, box/line intersections, naked tuples, hidden tuples and
finally X-Wings.  Each stage runs over all units before deciding; the first
stage that makes progress ends the attempt.

@return  1 if some stage made progress, 0 otherwise
*/
int Sudoku::eliminate()
{
	bool progress=false;

	// Numbers that fit in only one cell of a unit
	for(unsigned u=0; u<27; ++u)
		progress=units[u]->check_uniques() || progress;
	if(progress)
		return 1;

	// Box/line intersections
	if(reduce_intersections()>0)
		return 1;

	// Naked tuples
	for(unsigned u=0; u<27; ++u)
		progress=units[u]->find_naked_tuples() || progress;
	if(progress)
		return 1;

	// Hidden tuples
	for(unsigned u=0; u<27; ++u)
		progress=units[u]->find_hidden_tuples() || progress;
	if(progress)
		return 1;

	// Last resort: X-Wing
	return xwing() ? 1 : 0;
}

/**
Loads a puzzle from a string.  The string must be exactly 81 characters long,
and each character must be in the range '0'-'9', '0' being unset.
*/
void Sudoku::load_from_string(const string &str)
{
	if(str.size()!=9*9) return;

	for(unsigned i=0; i<9*9; ++i)
	{
		cells[i]=str[i]-'0';
		if(cells[i].get_num())
			cells[i].set_fixed(true);
	}
}

/**
Fixes all cells that currently have a number.
*/
void Sudoku::fix_cells()
{
	for(unsigned i=0; i<9*9; ++i)
		if(cells[i].get_num())
			cells[i].set_fixed(true);
}

/**
Clears any non-fixed cells, resetting the puzzle to the initial state.
*/
void Sudoku::clear_nonfixed()
{
	for(unsigned i=0; i<9*9; ++i)
		if(!cells[i].get_fixed())
			cells[i]=0;
	clear_soft_masks();
}

/*** private ***/

/**
Fills in a single empty cell if exactly one candidate is left for it.

@param   c  Cell to examine
@return  1 if the cell was filled in, 0 otherwise
*/
int Sudoku::solve_cell(Cell &c)
{
	// Already has a number: nothing to do
	if(c.get_num())
		return 0;

	// Zero or several candidates: can't decide yet
	const Mask candidates=c.get_mask();
	if(bitcount(candidates)!=1)
		return 0;

	c=bitindex(candidates);
	if(verbose)
		cerr<<"Solved cell "<<c.get_x()<<','<<c.get_y()<<" with "<<c.get_num()<<'\n';

	return 1;
}

/**
Clears the soft masks of all cells.  This must be called if a cell is cleared
or overwritten.
*/
void Sudoku::clear_soft_masks()
{
	for(unsigned i=0; i<9*9; ++i)
		cells[i].clear_soft_mask();
}

/**
Goes through all possible combinations of units that may have intersections and
reduces them.

@return  Whether any intersections were found.
*/
int Sudoku::reduce_intersections()
{
	bool found=false;

	for(unsigned y=0; y<9; ++y)
		for(unsigned x=0; x<3; ++x)
		{
			if(rows[y].reduce_intersections(boxes[y/3*3+x])>0)
				found=true;
			if(boxes[y/3*3+x].reduce_intersections(rows[y]))
				found=true;
		}

	for(unsigned x=0; x<9; ++x)
		for(unsigned y=0; y<3; ++y)
		{
			if(cols[x].reduce_intersections(boxes[x%3+y*3]))
				found=true;
			if(boxes[x%3+y*3].reduce_intersections(cols[x]))
				found=true;
		}

	return found;
}

/**
Employs the X-Wing tactic on every pair of rows and on every pair of columns.

@return  Whether any pair eliminated something
*/
int Sudoku::xwing()
{
	int hits=0;
	for(unsigned a=0; a<8; ++a)
		for(unsigned b=a+1; b<9; ++b)
		{
			if(rows[a].xwing(rows[b]))
				++hits;
			if(cols[a].xwing(cols[b]))
				++hits;
		}

	return hits>0;
}

/**
Tries to guess the value for one cell, hoping that it will lead to a solution.
*/
int Sudoku::brute_force()
{
	vector<Cell> backup=cells;

	for(unsigned i=0; (!solved && i<9*9); ++i)
	{
		if(cells[i].get_num()) continue;
		for(Number j=1; (!solved && j<=9); ++j)
		{
			if(!cells[i].get_mask(j)) continue;

			if(verbose)
				cerr<<"Bruteforcing cell "<<i%9+1<<','<<i/9+1<<" with "<<j<<'\n';

			cells[i]=j;
			solve();
			if(!solved)
			{
				// Must restore cells one by one to keep unit masks in sync
				for(unsigned k=0; k<9*9; ++k)
					cells[k]=backup[k].get_num();
				if(verbose)
					cerr<<"Bruteforcing failed\n";
			}
		}
	}

	return 0;
}

/**
Checks if there are any cells with no possible solutions.
*/
int Sudoku::detect_impossibilities()
{
	for(unsigned i=0; i<9*9; ++i)
		if(!cells[i].get_mask() && !cells[i].get_num())
		{
			if(verbose)
				cerr<<"Cell "<<i%9+1<<','<<i/9+1<<" doesn't have any solutions\n";
			return 1;
		}

	return 0;
}

/*******************
** Sudoku::Cell
*/

/**
Creates an empty, non-fixed cell at the given (1-based) position.  Every number
1-9 starts out as a candidate.  The cell isn't attached to any units until
set_units() is called.
*/
Sudoku::Cell::Cell(unsigned short x_, unsigned short y_):
	x(x_), y(y_),
	num(0),
	// bits 1..9 set: every number is possible
	mask((1<<10)-2), soft_mask((1<<10)-2),
	fixed(false),
	row(0), col(0), box(0)
{ }

/**
Associates the cell with the given row, column and box, and registers it with
each of them.  Can't be done in the constructor, since we copy the cells
around a bit when bruteforcing.
*/
void Sudoku::Cell::set_units(Unit &r, Unit &c, Unit &b)
{
	// Remember our units first...
	row=&r;
	col=&c;
	box=&b;

	// ...then let each of them know about us
	r.add_cell(*this);
	c.add_cell(*this);
	b.add_cell(*this);
}

/**
Computes the cell's hard mask from those of the units it belongs to.
*/
void Sudoku::Cell::update_mask()
{
	mask=row->get_mask()&col->get_mask()&box->get_mask();
}

/**
Sets the cell to the given number, or clears it when n is 0.

The request is silently ignored if n is above 9, or if it is nonzero and not
allowed by the cell's hard mask.  Otherwise the old number (if any) is released
in the cell's row, column and box, and the new one is claimed there.
*/
Sudoku::Cell &Sudoku::Cell::operator=(Number n)
{
	// Range check first, so the shift below is never evaluated for large n
	if(n>9 || (n && !(mask&(1<<n))))
		return *this;

	Unit *const owners[3]={row, col, box};

	// Give the current number back to our units
	if(num)
		for(unsigned k=0; k<3; ++k)
			owners[k]->unmask_num(num);

	num=n;

	// And take the new one away from them
	if(num)
		for(unsigned k=0; k<3; ++k)
			owners[k]->mask_num(num);

	return *this;
}

/*******************
** Sudoku::Unit
*/

/**
Counts the unsolved cells of this unit that still accept the given number.
Same as bitcount(get_cell_mask(n)) but a bit more efficient.
*/
unsigned Sudoku::Unit::get_count(Number n) const
{
	unsigned count=0;
	for(unsigned k=0; k<cells.size(); ++k)
	{
		Cell *c=cells[k];
		if(c->get_num())
			continue;
		if(c->get_mask(n))
			++count;
	}
	return count;
}

/**
Returns a mask of the unsolved cells (bit k = k-th cell of the unit) where the
given number could be placed.
*/
Sudoku::Mask Sudoku::Unit::get_cell_mask(Number n) const
{
	Mask positions=0;
	Mask bit=1;
	for(unsigned k=0; k<9; ++k, bit<<=1)
	{
		Cell *c=cells[k];
		if(c->get_num())
			continue;
		if(c->get_mask(n))
			positions|=bit;
	}
	return positions;
}

/**
Marks a number as unavailable in this unit, normally because one of its cells
now holds it, and refreshes the masks of all cells in the unit.
*/
void Sudoku::Unit::mask_num(Number n)
{
	const Mask bit=1<<n;
	mask&=~bit;
	update_cells();
}

/**
Marks a number as free again in this unit and refreshes the masks of all cells
in the unit.
*/
void Sudoku::Unit::unmask_num(Number n)
{
	const Mask bit=1<<n;
	mask|=bit;
	update_cells();
}

/**
Makes every cell in the unit recompute its hard mask.
*/
void Sudoku::Unit::update_cells()
{
	for(unsigned k=0; k<cells.size(); ++k)
		cells[k]->update_mask();
}

/**
Checks for numbers in this unit that have only one possible solution.  As a
result, those cells can't have any other values or the puzzle would not be
solvable, so all other candidates are removed from them.

@return  Whether any candidate was actually eliminated
*/
int Sudoku::Unit::check_uniques()
{
	if(!mask) return 0;

	// Create a mask of all numbers that can only be placed at one position
	Mask umask=0;
	for(Number i=1; i<=9; ++i)
		if(get_count(i)==1)
			umask|=1<<i;

	// No such numbers found, we can't do anything
	if(!umask) return 0;

	// Mask out all other numbers from the relevant cells
	bool changed=false;
	for(vector<Cell *>::const_iterator i=cells.begin(); i!=cells.end(); ++i)
	{
		if((*i)->get_num()) continue;

		Mask m=(*i)->get_mask()&umask;
		if(m && (*i)->get_mask()!=m)
		{
			(*i)->mask_nums(~m);
			changed=true;
			// Report the unique number(s) found in this cell, not the unit mask
			if(sudoku->verbose) cerr<<bitindex(m)<<" Has a unique solution in "<<type<<' '<<index<<'\n';
		}
	}

	/* Report progress only when something changed.  Returning 1 merely because
	unique numbers exist makes solve() loop forever on a cell whose candidates
	are all unique to it (an impossible state that brute forcing can reach). */
	return changed;
}

/**
If all solutions for a certain number in the other unit are on this unit, then
that number can be eliminated for all cells on this unit that are not part of
the other unit.

@return  Whether there was an intersection
*/
int Sudoku::Unit::reduce_intersections(const Unit &other)
{
	bool found=false;

	for(Number i=1; i<=9; ++i)
	{
		/* Check for possible places for the number in the other cell that are
		outside this one */
		bool ok=true;
		for(vector<Cell *>::const_iterator j=other.get_cells().begin(); (j!=other.get_cells().end() && ok); ++j)
			if(((*j)->get_mask(i) && !(*j)->part_of(this) && !(*j)->get_num()) || (*j)->get_num()==i)
				ok=false;

		if(ok)
		{
			// No possibilities outside this unit.  We have an intersection!

			/* Now mask the number out from all local cells that are not part of
			the other unit. */
			bool found_this=false;
			for(vector<Cell *>::iterator j=cells.begin(); j!=cells.end(); ++j)
				if(!(*j)->part_of(&other) && (*j)->get_mask(i))
				{
					(*j)->mask_num(i);
					found=found_this=true;
				}

			if(found_this && sudoku->verbose)
				cerr<<"Intersection of "<<i<<" in "<<type<<' '<<index<<" and "<<other.get_type()<<' '<<other.get_index()<<" reduced\n";
		}
	}

	return found;
}

/**
Finds groups of N>1 cells that can only have the same N numbers.  The rest of
the unit can't have those numbers.

@return  Whether any such tuples were found.
*/
int Sudoku::Unit::find_naked_tuples()
{
	bool found=false;

	/* Take copies of cell masks, because the originals get modified along the
	way.  Also make sure the masks for any used cells are 0. */
	vector<Mask> masks(9,0);
	for(unsigned i=0; i<9; ++i)
		if(!cells[i]->get_num())
			masks[i]=cells[i]->get_mask();

	for(unsigned i=0; i<8; ++i)
	{
		if(!masks[i]) continue;

		// Count the equal masks
		Mask eqmask=1<<i;
		for(unsigned j=i+1; j<9; ++j)
			if(masks[j]==masks[i])
				eqmask|=1<<j;

		// We need to have as many equal masks as there are bits in them
		if(bitcount(eqmask)!=bitcount(masks[i]) || bitcount(eqmask)==1) continue;

		bool found_this=false;
		// Mask the numbers from all other cells
		for(unsigned k=0; k<9; ++k)
			if(!(eqmask&(1<<k)) && (cells[k]->get_mask()&masks[i]) && !(cells[k]->get_num()))
			{
				cells[k]->mask_nums(masks[i]);
				found=found_this=true;
			}

		if(found_this && sudoku->verbose)
		{
			cerr<<"Found naked tuple of ";
			print_bits(cerr,masks[i]);
			cerr<<" in "<<type<<' '<<index<<'\n';
		}

		// Zero out the masks to mark them processed
		for(unsigned j=i; j<9; ++j)
			if(eqmask&(1<<j))
				masks[j]=0;
	}

	return found;
}

/**
Finds groups of N>1 numbers that can only exist in the same N cells.  As a
consequence, those cells can't have any other numbers.
*/
int Sudoku::Unit::find_hidden_tuples()
{
	bool found=false;

	// Create cell masks for each number
	vector<Mask> masks(9,0);
	for(Number i=1; i<=9; ++i)
		masks[i-1]=get_cell_mask(i);

	for(unsigned i=0; i<8; ++i)
	{
		if(!masks[i]) continue;

		// Count the equal masks
		Mask eqmask=2<<i;
		for(unsigned j=i+1; j<9; ++j)
			if(masks[j]==masks[i])
				eqmask|=2<<j;

		// We need to have as many equal masks as there are bits in them
		if(bitcount(eqmask)!=bitcount(masks[i]) || bitcount(eqmask)==1) continue;

		bool found_this=false;
		// Mask all other numbers from the set of cells
		for(unsigned j=0; j<9; ++j)
			if(masks[i]&(1<<j) && (cells[j]->get_mask()&~eqmask))
			{
				cells[j]->mask_nums(~eqmask);
				found=found_this=true;
			}

		if(found_this && sudoku->verbose)
		{
			cerr<<"Found hidden tuple of ";
			print_bits(cerr,eqmask);
			cerr<<" in "<<type<<' '<<index<<'\n';
		}

		// Zero out the masks to mark them processed
		for(unsigned j=i; j<9; ++j)
			if(eqmask&(2<<j))
				masks[j]=0;
	}

	return found;
}

/**
If this cell and the other one have exactly two solutions for some number and
in the same indices, that number may be eliminated from other cells in the units
crossing at those indices.

@return  Whether anything was eliminated
*/
int Sudoku::Unit::xwing(const Unit &other)
{
	// Sanity check
	if((type!=ROW && type!=COL) || type!=other.get_type()) return 0;

	bool found=false;

	for(Number i=1; i<=9; ++i)
	{
		Mask mask1=get_cell_mask(i);
		Mask mask2=other.get_cell_mask(i);
		if(mask1!=mask2 || bitcount(mask1)!=2)
			continue;

		// We found an X-Wing, see if we can shoot anything down
		bool found_this=false;
		for(unsigned j=0; j<9; ++j)
			if(mask1&(1<<j))
			{
				// Get a reference to the crossing unit and its cells
				Unit &third=(type==ROW) ? cells[j]->get_column() : cells[j]->get_row();
				const vector<Cell *> &tcells=third.get_cells();

				/* Eliminate the number from all cells that are not part of either
				of the focus units */
				for(unsigned k=0; k<9; ++k)
					if(!tcells[k]->part_of(this) && !tcells[k]->part_of(&other) && tcells[k]->get_mask(i) && !tcells[k]->get_num())
					{
						tcells[k]->mask_num(i);
						found=found_this=true;
					}
			}

		if(found_this && sudoku->verbose)
			cerr<<"Found x-wing of "<<i<<" on "<<type<<' '<<index<<','<<other.get_index()<<'\n';
	}

	return found;
}
